add utils to get hostname from request and process from hostname

This commit is contained in:
Jake Howard 2018-12-12 23:15:02 +00:00
parent 9168ac32de
commit aff11d7d84
Signed by: jake
GPG key ID: 57AFB45680EDD477
3 changed files with 67 additions and 1 deletions

View file

@ -1,9 +1,11 @@
import click import click
from aiohttp import web from aiohttp import web
from .utils import get_hostname_from_request
async def handle_request(request): async def handle_request(request):
return web.json_response({}) return web.json_response({"host": get_hostname_from_request(request)})
def get_server(): def get_server():

24
catfish/router/utils.py Normal file
View file

@ -0,0 +1,24 @@
from typing import Optional
import psutil
from catfish.utils.processes import get_root_process
ROOT_PROCESS = get_root_process()
HOSTNAME_ENV_VAR = "VIRTUAL_HOST"
def get_process_for_hostname(hostname: str) -> Optional[psutil.Process]:
for process in ROOT_PROCESS.children(recursive=True):
try:
if hostname in process.environ().get(HOSTNAME_ENV_VAR, "").split(","):
return process
except psutil.AccessDenied:
continue
return None
def get_hostname_from_request(request) -> str:
if getattr(request, "url", None) is not None:
return request.url.host
return request.host.split(":")[0]

View file

@ -0,0 +1,40 @@
import os
import subprocess
from aiohttp.test_utils import make_mocked_request
from catfish.router import utils
from tests import BaseTestCase
class ProcessForHostnameTestCase(BaseTestCase):
def test_no_matching_processes(self):
self.assertIsNone(utils.get_process_for_hostname("localhost"))
def test_finds_process(self):
proc = subprocess.Popen(
"yes",
stdout=subprocess.PIPE,
env={**os.environ, utils.HOSTNAME_ENV_VAR: "localhost"},
)
self.assertIsNone(proc.poll())
self.assertEqual(utils.get_process_for_hostname("localhost").pid, proc.pid)
class HostnameFromRequestTestCase(BaseTestCase):
def get_request_for_hostname(self, hostname):
return make_mocked_request("GET", "/", headers={"HOST": hostname})
def test_hostname(self):
self.assertEqual(
utils.get_hostname_from_request(self.get_request_for_hostname("localhost")),
"localhost",
)
def test_hostname_with_port(self):
self.assertEqual(
utils.get_hostname_from_request(
self.get_request_for_hostname("localhost:8080")
),
"localhost",
)