From 1e190c79322bb93f1ab92c3b6cf7ff4126e24fab Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Wed, 19 Dec 2018 13:43:28 +0000 Subject: [PATCH] Better handle stopping router in tests --- catfish/utils/processes.py | 15 +++++++++++++-- catfish/worker/server.py | 19 ++++++++++++++----- tests/__init__.py | 1 + tests/test_worker/test_server.py | 3 ++- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/catfish/utils/processes.py b/catfish/utils/processes.py index e2e459d..34eadd0 100644 --- a/catfish/utils/processes.py +++ b/catfish/utils/processes.py @@ -3,6 +3,8 @@ from typing import List import psutil +from catfish.project import Process + CURRENT_PROCESS = psutil.Process() @@ -20,8 +22,8 @@ def terminate_processes(procs: List[psutil.Process], timeout=3): return alive -def terminate_subprocesses(): - return terminate_processes(CURRENT_PROCESS.children(recursive=True)) +def terminate_subprocesses(parent: psutil.Process = CURRENT_PROCESS): + terminate_processes(parent.children(recursive=True)) def get_root_process() -> psutil.Process: @@ -47,3 +49,12 @@ def wait_for_process_start(pid: int): def wait_for_process_terminate(pid: int): while is_process_running(pid): time.sleep(0.1) + + +def find_running_process_for(process: Process): + for proc in get_root_process().children(recursive=True): + try: + if proc.environ().get("CATFISH_IDENT") == process.ident: + return proc + except psutil.AccessDenied: + continue diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 615c94f..6178c35 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -12,7 +12,7 @@ import zmq import ujson from catfish.project import Process, Project from catfish.utils import aio -from catfish.utils.processes import terminate_processes +from catfish.utils.processes import CURRENT_PROCESS, terminate_subprocesses from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock") @@ -62,15 +62,24 @@ async def publish_stdout_for(process, ctf_process: Process): await aio.remove_file(socket_path) -async def run_process_command(project: Project, process: Process): +async def start_process(project: Project, process: Process): command = shlex.split(process.command) - proc = await asyncio.create_subprocess_exec( + return await asyncio.create_subprocess_exec( *command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - env={**os.environ, **project.get_environment(), "CATFISH_IDENT": process.ident}, + env={ + **os.environ, + **project.get_environment(), + "CATFISH_IDENT": process.ident, + "CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid), + }, cwd=project.root ) + + +async def run_process_command(project: Project, process: Process): + proc = await start_process(project, process) asyncio.ensure_future(publish_stdout_for(proc, process)) return proc @@ -96,7 +105,7 @@ async def client_connected(reader, writer): async def start_server(): - atexit.register(terminate_processes) + atexit.register(terminate_subprocesses) server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET) click.echo("Started server") await server.serve_forever() diff --git a/tests/__init__.py b/tests/__init__.py index a1cc234..74cb424 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -67,6 +67,7 @@ class BaseWorkerTestCase(BaseTestCase): time.sleep(0.1) def tearDown(self): + terminate_subprocesses(psutil.Process(self.worker_process.pid)) self.worker_process.terminate() self.worker_process.wait() super().tearDown() diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index 5ae84e8..f566204 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -1,7 +1,7 @@ import psutil from catfish.project import Project -from catfish.utils.processes import is_process_running +from catfish.utils.processes import find_running_process_for, is_process_running from catfish.worker import BASE_SOCKET_DIR from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server from tests import BaseWorkerTestCase @@ -25,6 +25,7 @@ class ProcessWorkerTestCase(BaseWorkerTestCase): {"path": str(self.project.root), "process": str(self.process.name)}, ) self.assertTrue(is_process_running(response["pid"])) + self.assertEqual(find_running_process_for(self.process).pid, response["pid"]) def test_additional_environment(self): response = send_to_server(