diff --git a/catfish/utils/processes.py b/catfish/utils/processes.py index 6822dc4..a1382a9 100644 --- a/catfish/utils/processes.py +++ b/catfish/utils/processes.py @@ -2,6 +2,8 @@ from typing import List import psutil +CURRENT_PROCESS = psutil.Process() + def terminate_processes(procs: List[psutil.Process], timeout=3): # https://psutil.readthedocs.io/en/latest/#terminate-my-children @@ -14,3 +16,11 @@ def terminate_processes(procs: List[psutil.Process], timeout=3): for p in alive: p.kill() gone, alive = psutil.wait_procs(alive, timeout=timeout) + return alive + + +def get_root_process() -> psutil.Process: + proc = CURRENT_PROCESS + while proc.parent(): + proc = proc.parent() + return proc diff --git a/catfish/worker/router.py b/catfish/worker/router.py index 6d35000..92d15da 100644 --- a/catfish/worker/router.py +++ b/catfish/worker/router.py @@ -5,9 +5,11 @@ from aiohttp import web async def handle_request(request): return web.json_response({}) + def get_server(): return web.Server(handle_request) + async def run_server(loop, port): click.echo("Starting server...") aiohttp_server = get_server() diff --git a/tests/__init__.py b/tests/__init__.py index 34f552c..27db563 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,6 +3,7 @@ from aiohttp.test_utils import unused_port from click.testing import CliRunner from catfish.__main__ import cli from catfish import worker +from catfish.utils.processes import terminate_processes, CURRENT_PROCESS import functools from multiprocessing import Process @@ -16,13 +17,16 @@ class BaseTestCase(AsyncTestCase): def tearDown(self): worker.stop_worker() + terminate_processes(CURRENT_PROCESS.children(recursive=True)) class BaseWorkerTestCase(BaseTestCase): def setUp(self): super().setUp() self.unused_port = unused_port() - self.router_process = Process(target=worker.run, args=(self.unused_port,), daemon=True) + self.router_process = Process( + target=worker.run, args=(self.unused_port,), daemon=True + ) self.router_process.start() def tearDown(self): diff --git a/tests/test_utils/test_processes.py b/tests/test_utils/test_processes.py new file mode 100644 index 0000000..2685780 --- /dev/null +++ b/tests/test_utils/test_processes.py @@ -0,0 +1,35 @@ +import os +import subprocess + +import psutil + +from catfish.utils import processes +from tests import BaseTestCase + + +class RootProcessTestCase(BaseTestCase): + def test_root_processes(self): + root_process = processes.get_root_process() + self.assertIsNone(root_process.parent()) + all_children = [proc.pid for proc in root_process.children(recursive=True)] + self.assertIn(os.getpid(), all_children) + + def test_current_process(self): + self.assertEqual(processes.CURRENT_PROCESS.pid, os.getpid()) + + +class TerminateProcessesTestCase(BaseTestCase): + def test_kills_lots_of_processes(self): + created_processes = [] + for _ in range(10): + created_processes.append(subprocess.Popen("yes", stdout=subprocess.PIPE)) + + for proc in created_processes: + self.assertIsNone(proc.poll()) + + still_alive = processes.terminate_processes( + [psutil.Process(proc.pid) for proc in created_processes] + ) + self.assertIsNone(still_alive) + for proc in created_processes: + self.assertEqual(proc.poll(), 0)