Allow processes to be restarted and terminated
This commit is contained in:
parent
1e190c7932
commit
fffcbf8f48
2 changed files with 53 additions and 12 deletions
|
@ -2,6 +2,7 @@ import asyncio
|
||||||
import atexit
|
import atexit
|
||||||
import os
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
@ -46,20 +47,26 @@ def write_data(writer, data):
|
||||||
writer.write(ujson.dumps(data).encode())
|
writer.write(ujson.dumps(data).encode())
|
||||||
|
|
||||||
|
|
||||||
async def publish_stdout_for(process, ctf_process: Process):
|
async def publish_stdout_for(process, ctf_process: Process, project: Project):
|
||||||
ctx = zmq.Context()
|
ctx = zmq.Context()
|
||||||
sock = ctx.socket(zmq.PUB)
|
sock = ctx.socket(zmq.PUB)
|
||||||
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
|
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
|
||||||
sock.bind("ipc://" + socket_path)
|
sock.bind("ipc://" + socket_path)
|
||||||
while True:
|
try:
|
||||||
output = await process.stdout.readline()
|
while True:
|
||||||
if not output:
|
while True:
|
||||||
break
|
output = await process.stdout.readline()
|
||||||
sock.send_string(output.decode())
|
if not output:
|
||||||
process.kill()
|
break
|
||||||
sock.close()
|
sock.send_string(output.decode())
|
||||||
ctx.destroy()
|
await process.wait()
|
||||||
await aio.remove_file(socket_path)
|
exit_code = process.returncode
|
||||||
|
if exit_code == -signal.SIGHUP:
|
||||||
|
process = await start_process(project, ctf_process)
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
ctx.destroy()
|
||||||
|
await aio.remove_file(socket_path)
|
||||||
|
|
||||||
|
|
||||||
async def start_process(project: Project, process: Process):
|
async def start_process(project: Project, process: Process):
|
||||||
|
@ -80,7 +87,7 @@ async def start_process(project: Project, process: Process):
|
||||||
|
|
||||||
async def run_process_command(project: Project, process: Process):
|
async def run_process_command(project: Project, process: Process):
|
||||||
proc = await start_process(project, process)
|
proc = await start_process(project, process)
|
||||||
asyncio.ensure_future(publish_stdout_for(proc, process))
|
asyncio.ensure_future(publish_stdout_for(proc, process, project))
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
from catfish.project import Project
|
from catfish.project import Project
|
||||||
from catfish.utils.processes import find_running_process_for, is_process_running
|
from catfish.utils.processes import (
|
||||||
|
find_running_process_for,
|
||||||
|
is_process_running,
|
||||||
|
terminate_processes,
|
||||||
|
wait_for_process_terminate,
|
||||||
|
)
|
||||||
from catfish.worker import BASE_SOCKET_DIR
|
from catfish.worker import BASE_SOCKET_DIR
|
||||||
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
|
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
|
||||||
from tests import BaseWorkerTestCase
|
from tests import BaseWorkerTestCase
|
||||||
|
@ -47,6 +55,32 @@ class ProcessWorkerTestCase(BaseWorkerTestCase):
|
||||||
for path in self.project.get_extra_path():
|
for path in self.project.get_extra_path():
|
||||||
self.assertIn(str(path), path_dirs)
|
self.assertIn(str(path), path_dirs)
|
||||||
|
|
||||||
|
def test_process_restart(self):
|
||||||
|
response = send_to_server(
|
||||||
|
PayloadType.PROCESS,
|
||||||
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
|
)
|
||||||
|
initial_pid = response["pid"]
|
||||||
|
psutil.Process(initial_pid).send_signal(signal.SIGHUP)
|
||||||
|
wait_for_process_terminate(initial_pid)
|
||||||
|
time.sleep(2)
|
||||||
|
new_process = find_running_process_for(self.process)
|
||||||
|
self.assertNotEqual(new_process.pid, initial_pid)
|
||||||
|
self.assertFalse(is_process_running(initial_pid))
|
||||||
|
self.assertTrue(is_process_running(new_process.pid))
|
||||||
|
|
||||||
|
def test_process_terminate(self):
|
||||||
|
response = send_to_server(
|
||||||
|
PayloadType.PROCESS,
|
||||||
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
|
)
|
||||||
|
initial_pid = response["pid"]
|
||||||
|
terminate_processes([psutil.Process(initial_pid)])
|
||||||
|
wait_for_process_terminate(initial_pid)
|
||||||
|
time.sleep(2)
|
||||||
|
self.assertIsNone(find_running_process_for(self.process))
|
||||||
|
self.assertFalse(is_process_running(initial_pid))
|
||||||
|
|
||||||
|
|
||||||
class ProcessLogsTestCase(BaseWorkerTestCase):
|
class ProcessLogsTestCase(BaseWorkerTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
Reference in a new issue