add utils to get hostname from request and process from hostname
This commit is contained in:
parent
9168ac32de
commit
aff11d7d84
3 changed files with 67 additions and 1 deletions
|
@ -1,9 +1,11 @@
|
|||
import click
|
||||
from aiohttp import web
|
||||
|
||||
from .utils import get_hostname_from_request
|
||||
|
||||
|
||||
async def handle_request(request):
|
||||
return web.json_response({})
|
||||
return web.json_response({"host": get_hostname_from_request(request)})
|
||||
|
||||
|
||||
def get_server():
|
||||
|
|
24
catfish/router/utils.py
Normal file
24
catfish/router/utils.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
|
||||
from catfish.utils.processes import get_root_process
|
||||
|
||||
ROOT_PROCESS = get_root_process()
|
||||
HOSTNAME_ENV_VAR = "VIRTUAL_HOST"
|
||||
|
||||
|
||||
def get_process_for_hostname(hostname: str) -> Optional[psutil.Process]:
|
||||
for process in ROOT_PROCESS.children(recursive=True):
|
||||
try:
|
||||
if hostname in process.environ().get(HOSTNAME_ENV_VAR, "").split(","):
|
||||
return process
|
||||
except psutil.AccessDenied:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def get_hostname_from_request(request) -> str:
|
||||
if getattr(request, "url", None) is not None:
|
||||
return request.url.host
|
||||
return request.host.split(":")[0]
|
40
tests/test_router/test_utils.py
Normal file
40
tests/test_router/test_utils.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
import subprocess
|
||||
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
|
||||
from catfish.router import utils
|
||||
from tests import BaseTestCase
|
||||
|
||||
|
||||
class ProcessForHostnameTestCase(BaseTestCase):
|
||||
def test_no_matching_processes(self):
|
||||
self.assertIsNone(utils.get_process_for_hostname("localhost"))
|
||||
|
||||
def test_finds_process(self):
|
||||
proc = subprocess.Popen(
|
||||
"yes",
|
||||
stdout=subprocess.PIPE,
|
||||
env={**os.environ, utils.HOSTNAME_ENV_VAR: "localhost"},
|
||||
)
|
||||
self.assertIsNone(proc.poll())
|
||||
self.assertEqual(utils.get_process_for_hostname("localhost").pid, proc.pid)
|
||||
|
||||
|
||||
class HostnameFromRequestTestCase(BaseTestCase):
|
||||
def get_request_for_hostname(self, hostname):
|
||||
return make_mocked_request("GET", "/", headers={"HOST": hostname})
|
||||
|
||||
def test_hostname(self):
|
||||
self.assertEqual(
|
||||
utils.get_hostname_from_request(self.get_request_for_hostname("localhost")),
|
||||
"localhost",
|
||||
)
|
||||
|
||||
def test_hostname_with_port(self):
|
||||
self.assertEqual(
|
||||
utils.get_hostname_from_request(
|
||||
self.get_request_for_hostname("localhost:8080")
|
||||
),
|
||||
"localhost",
|
||||
)
|
Reference in a new issue