diff --git a/catfish/utils/processes.py b/catfish/utils/processes.py index fbe83cd..e2e459d 100644 --- a/catfish/utils/processes.py +++ b/catfish/utils/processes.py @@ -20,6 +20,10 @@ def terminate_processes(procs: List[psutil.Process], timeout=3): return alive +def terminate_subprocesses(): + return terminate_processes(CURRENT_PROCESS.children(recursive=True)) + + def get_root_process() -> psutil.Process: proc = CURRENT_PROCESS while proc.parent(): diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 8d7c776..2dcd6fa 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -1,15 +1,16 @@ import asyncio +import atexit import os import shlex import socket import subprocess -import time import click import zmq import ujson from catfish.utils import aio +from catfish.utils.processes import terminate_processes from catfish.utils.sockets import ( BASE_SOCKET_DIR, NEW_LINE, @@ -78,6 +79,7 @@ async def client_connected(reader, writer): async def start_server(): + atexit.register(terminate_processes) server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET) click.echo("Started server") await server.serve_forever() diff --git a/tests/__init__.py b/tests/__init__.py index 2176dae..3116ecb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -8,7 +8,7 @@ from aiounittest import AsyncTestCase from click.testing import CliRunner from catfish.__main__ import cli -from catfish.utils.processes import CURRENT_PROCESS, terminate_processes +from catfish.utils.processes import terminate_subprocesses from catfish.utils.sockets import create_base_socket_dir, delete_base_socket_dir @@ -22,7 +22,7 @@ class BaseTestCase(AsyncTestCase): self.run_cli = functools.partial(self.cli_runner.invoke, self.cli) def tearDown(self): - terminate_processes(CURRENT_PROCESS.children(recursive=True)) + terminate_subprocesses() delete_base_socket_dir()