Allow processes to be restarted and terminated

This commit is contained in:
Jake Howard 2018-12-19 15:44:03 +00:00
parent 1e190c7932
commit fffcbf8f48
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 53 additions and 12 deletions

View file

@ -2,6 +2,7 @@ import asyncio
import atexit import atexit
import os import os
import shlex import shlex
import signal
import socket import socket
import subprocess import subprocess
from enum import Enum, auto from enum import Enum, auto
@ -46,17 +47,23 @@ def write_data(writer, data):
writer.write(ujson.dumps(data).encode()) writer.write(ujson.dumps(data).encode())
async def publish_stdout_for(process, ctf_process: Process): async def publish_stdout_for(process, ctf_process: Process, project: Project):
ctx = zmq.Context() ctx = zmq.Context()
sock = ctx.socket(zmq.PUB) sock = ctx.socket(zmq.PUB)
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket)) socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
sock.bind("ipc://" + socket_path) sock.bind("ipc://" + socket_path)
try:
while True:
while True: while True:
output = await process.stdout.readline() output = await process.stdout.readline()
if not output: if not output:
break break
sock.send_string(output.decode()) sock.send_string(output.decode())
process.kill() await process.wait()
exit_code = process.returncode
if exit_code == -signal.SIGHUP:
process = await start_process(project, ctf_process)
finally:
sock.close() sock.close()
ctx.destroy() ctx.destroy()
await aio.remove_file(socket_path) await aio.remove_file(socket_path)
@ -80,7 +87,7 @@ async def start_process(project: Project, process: Process):
async def run_process_command(project: Project, process: Process): async def run_process_command(project: Project, process: Process):
proc = await start_process(project, process) proc = await start_process(project, process)
asyncio.ensure_future(publish_stdout_for(proc, process)) asyncio.ensure_future(publish_stdout_for(proc, process, project))
return proc return proc

View file

@ -1,7 +1,15 @@
import signal
import time
import psutil import psutil
from catfish.project import Project from catfish.project import Project
from catfish.utils.processes import find_running_process_for, is_process_running from catfish.utils.processes import (
find_running_process_for,
is_process_running,
terminate_processes,
wait_for_process_terminate,
)
from catfish.worker import BASE_SOCKET_DIR from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase from tests import BaseWorkerTestCase
@ -47,6 +55,32 @@ class ProcessWorkerTestCase(BaseWorkerTestCase):
for path in self.project.get_extra_path(): for path in self.project.get_extra_path():
self.assertIn(str(path), path_dirs) self.assertIn(str(path), path_dirs)
def test_process_restart(self):
response = send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
initial_pid = response["pid"]
psutil.Process(initial_pid).send_signal(signal.SIGHUP)
wait_for_process_terminate(initial_pid)
time.sleep(2)
new_process = find_running_process_for(self.process)
self.assertNotEqual(new_process.pid, initial_pid)
self.assertFalse(is_process_running(initial_pid))
self.assertTrue(is_process_running(new_process.pid))
def test_process_terminate(self):
response = send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
initial_pid = response["pid"]
terminate_processes([psutil.Process(initial_pid)])
wait_for_process_terminate(initial_pid)
time.sleep(2)
self.assertIsNone(find_running_process_for(self.process))
self.assertFalse(is_process_running(initial_pid))
class ProcessLogsTestCase(BaseWorkerTestCase): class ProcessLogsTestCase(BaseWorkerTestCase):
def setUp(self): def setUp(self):