Better handle stopping router in tests

This commit is contained in:
Jake Howard 2018-12-19 13:43:28 +00:00
parent f6c8d5b19d
commit 1e190c7932
Signed by: jake
GPG key ID: 57AFB45680EDD477
4 changed files with 30 additions and 8 deletions

View file

@ -3,6 +3,8 @@ from typing import List
import psutil
from catfish.project import Process
CURRENT_PROCESS = psutil.Process()
@ -20,8 +22,8 @@ def terminate_processes(procs: List[psutil.Process], timeout=3):
return alive
def terminate_subprocesses():
return terminate_processes(CURRENT_PROCESS.children(recursive=True))
def terminate_subprocesses(parent: psutil.Process = CURRENT_PROCESS):
terminate_processes(parent.children(recursive=True))
def get_root_process() -> psutil.Process:
@ -47,3 +49,12 @@ def wait_for_process_start(pid: int):
def wait_for_process_terminate(pid: int):
while is_process_running(pid):
time.sleep(0.1)
def find_running_process_for(process: Process):
for proc in get_root_process().children(recursive=True):
try:
if proc.environ().get("CATFISH_IDENT") == process.ident:
return proc
except psutil.AccessDenied:
continue

View file

@ -12,7 +12,7 @@ import zmq
import ujson
from catfish.project import Process, Project
from catfish.utils import aio
from catfish.utils.processes import terminate_processes
from catfish.utils.processes import CURRENT_PROCESS, terminate_subprocesses
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
@ -62,15 +62,24 @@ async def publish_stdout_for(process, ctf_process: Process):
await aio.remove_file(socket_path)
async def run_process_command(project: Project, process: Process):
async def start_process(project: Project, process: Process):
command = shlex.split(process.command)
proc = await asyncio.create_subprocess_exec(
return await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={**os.environ, **project.get_environment(), "CATFISH_IDENT": process.ident},
env={
**os.environ,
**project.get_environment(),
"CATFISH_IDENT": process.ident,
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
},
cwd=project.root
)
async def run_process_command(project: Project, process: Process):
proc = await start_process(project, process)
asyncio.ensure_future(publish_stdout_for(proc, process))
return proc
@ -96,7 +105,7 @@ async def client_connected(reader, writer):
async def start_server():
atexit.register(terminate_processes)
atexit.register(terminate_subprocesses)
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
click.echo("Started server")
await server.serve_forever()

View file

@ -67,6 +67,7 @@ class BaseWorkerTestCase(BaseTestCase):
time.sleep(0.1)
def tearDown(self):
terminate_subprocesses(psutil.Process(self.worker_process.pid))
self.worker_process.terminate()
self.worker_process.wait()
super().tearDown()

View file

@ -1,7 +1,7 @@
import psutil
from catfish.project import Project
from catfish.utils.processes import is_process_running
from catfish.utils.processes import find_running_process_for, is_process_running
from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase
@ -25,6 +25,7 @@ class ProcessWorkerTestCase(BaseWorkerTestCase):
{"path": str(self.project.root), "process": str(self.process.name)},
)
self.assertTrue(is_process_running(response["pid"]))
self.assertEqual(find_running_process_for(self.process).pid, response["pid"])
def test_additional_environment(self):
response = send_to_server(