diff --git a/catfish/router/__init__.py b/catfish/router/__init__.py index 92d15da..a76a605 100644 --- a/catfish/router/__init__.py +++ b/catfish/router/__init__.py @@ -1,9 +1,11 @@ import click from aiohttp import web +from .utils import get_hostname_from_request + async def handle_request(request): - return web.json_response({}) + return web.json_response({"host": get_hostname_from_request(request)}) def get_server(): diff --git a/catfish/router/utils.py b/catfish/router/utils.py new file mode 100644 index 0000000..5232098 --- /dev/null +++ b/catfish/router/utils.py @@ -0,0 +1,24 @@ +from typing import Optional + +import psutil + +from catfish.utils.processes import get_root_process + +ROOT_PROCESS = get_root_process() +HOSTNAME_ENV_VAR = "VIRTUAL_HOST" + + +def get_process_for_hostname(hostname: str) -> Optional[psutil.Process]: + for process in ROOT_PROCESS.children(recursive=True): + try: + if hostname in process.environ().get(HOSTNAME_ENV_VAR, "").split(","): + return process + except psutil.AccessDenied: + continue + return None + + +def get_hostname_from_request(request) -> str: + if getattr(request, "url", None) is not None: + return request.url.host + return request.host.split(":")[0] diff --git a/tests/test_router/test_utils.py b/tests/test_router/test_utils.py new file mode 100644 index 0000000..457b812 --- /dev/null +++ b/tests/test_router/test_utils.py @@ -0,0 +1,40 @@ +import os +import subprocess + +from aiohttp.test_utils import make_mocked_request + +from catfish.router import utils +from tests import BaseTestCase + + +class ProcessForHostnameTestCase(BaseTestCase): + def test_no_matching_processes(self): + self.assertIsNone(utils.get_process_for_hostname("localhost")) + + def test_finds_process(self): + proc = subprocess.Popen( + "yes", + stdout=subprocess.PIPE, + env={**os.environ, utils.HOSTNAME_ENV_VAR: "localhost"}, + ) + self.assertIsNone(proc.poll()) + self.assertEqual(utils.get_process_for_hostname("localhost").pid, proc.pid) + + +class HostnameFromRequestTestCase(BaseTestCase): + def get_request_for_hostname(self, hostname): + return make_mocked_request("GET", "/", headers={"HOST": hostname}) + + def test_hostname(self): + self.assertEqual( + utils.get_hostname_from_request(self.get_request_for_hostname("localhost")), + "localhost", + ) + + def test_hostname_with_port(self): + self.assertEqual( + utils.get_hostname_from_request( + self.get_request_for_hostname("localhost:8080") + ), + "localhost", + )