From fffcbf8f484618b4b616c6c3b715bf32e56dd80e Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Wed, 19 Dec 2018 15:44:03 +0000 Subject: [PATCH] Allow processes to be restarted and terminated --- catfish/worker/server.py | 29 +++++++++++++++---------- tests/test_worker/test_server.py | 36 +++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 6178c35..13271d5 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -2,6 +2,7 @@ import asyncio import atexit import os import shlex +import signal import socket import subprocess from enum import Enum, auto @@ -46,20 +47,26 @@ def write_data(writer, data): writer.write(ujson.dumps(data).encode()) -async def publish_stdout_for(process, ctf_process: Process): +async def publish_stdout_for(process, ctf_process: Process, project: Project): ctx = zmq.Context() sock = ctx.socket(zmq.PUB) socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket)) sock.bind("ipc://" + socket_path) - while True: - output = await process.stdout.readline() - if not output: - break - sock.send_string(output.decode()) - process.kill() - sock.close() - ctx.destroy() - await aio.remove_file(socket_path) + try: + while True: + while True: + output = await process.stdout.readline() + if not output: + break + sock.send_string(output.decode()) + await process.wait() + exit_code = process.returncode + if exit_code == -signal.SIGHUP: + process = await start_process(project, ctf_process) + finally: + sock.close() + ctx.destroy() + await aio.remove_file(socket_path) async def start_process(project: Project, process: Process): @@ -80,7 +87,7 @@ async def start_process(project: Project, process: Process): async def run_process_command(project: Project, process: Process): proc = await start_process(project, process) - asyncio.ensure_future(publish_stdout_for(proc, process)) + asyncio.ensure_future(publish_stdout_for(proc, process, project)) return proc diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index f566204..1c51472 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -1,7 +1,15 @@ +import signal +import time + import psutil from catfish.project import Project -from catfish.utils.processes import find_running_process_for, is_process_running +from catfish.utils.processes import ( + find_running_process_for, + is_process_running, + terminate_processes, + wait_for_process_terminate, +) from catfish.worker import BASE_SOCKET_DIR from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server from tests import BaseWorkerTestCase @@ -47,6 +55,32 @@ class ProcessWorkerTestCase(BaseWorkerTestCase): for path in self.project.get_extra_path(): self.assertIn(str(path), path_dirs) + def test_process_restart(self): + response = send_to_server( + PayloadType.PROCESS, + {"path": str(self.project.root), "process": str(self.process.name)}, + ) + initial_pid = response["pid"] + psutil.Process(initial_pid).send_signal(signal.SIGHUP) + wait_for_process_terminate(initial_pid) + time.sleep(2) + new_process = find_running_process_for(self.process) + self.assertNotEqual(new_process.pid, initial_pid) + self.assertFalse(is_process_running(initial_pid)) + self.assertTrue(is_process_running(new_process.pid)) + + def test_process_terminate(self): + response = send_to_server( + PayloadType.PROCESS, + {"path": str(self.project.root), "process": str(self.process.name)}, + ) + initial_pid = response["pid"] + terminate_processes([psutil.Process(initial_pid)]) + wait_for_process_terminate(initial_pid) + time.sleep(2) + self.assertIsNone(find_running_process_for(self.process)) + self.assertFalse(is_process_running(initial_pid)) + class ProcessLogsTestCase(BaseWorkerTestCase): def setUp(self):