Better handle stopping router in tests
This commit is contained in:
parent
f6c8d5b19d
commit
1e190c7932
4 changed files with 30 additions and 8 deletions
|
@ -3,6 +3,8 @@ from typing import List
|
|||
|
||||
import psutil
|
||||
|
||||
from catfish.project import Process
|
||||
|
||||
CURRENT_PROCESS = psutil.Process()
|
||||
|
||||
|
||||
|
@ -20,8 +22,8 @@ def terminate_processes(procs: List[psutil.Process], timeout=3):
|
|||
return alive
|
||||
|
||||
|
||||
def terminate_subprocesses():
|
||||
return terminate_processes(CURRENT_PROCESS.children(recursive=True))
|
||||
def terminate_subprocesses(parent: psutil.Process = CURRENT_PROCESS):
|
||||
terminate_processes(parent.children(recursive=True))
|
||||
|
||||
|
||||
def get_root_process() -> psutil.Process:
|
||||
|
@ -47,3 +49,12 @@ def wait_for_process_start(pid: int):
|
|||
def wait_for_process_terminate(pid: int):
|
||||
while is_process_running(pid):
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def find_running_process_for(process: Process):
|
||||
for proc in get_root_process().children(recursive=True):
|
||||
try:
|
||||
if proc.environ().get("CATFISH_IDENT") == process.ident:
|
||||
return proc
|
||||
except psutil.AccessDenied:
|
||||
continue
|
||||
|
|
|
@ -12,7 +12,7 @@ import zmq
|
|||
import ujson
|
||||
from catfish.project import Process, Project
|
||||
from catfish.utils import aio
|
||||
from catfish.utils.processes import terminate_processes
|
||||
from catfish.utils.processes import CURRENT_PROCESS, terminate_subprocesses
|
||||
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
||||
|
||||
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
||||
|
@ -62,15 +62,24 @@ async def publish_stdout_for(process, ctf_process: Process):
|
|||
await aio.remove_file(socket_path)
|
||||
|
||||
|
||||
async def run_process_command(project: Project, process: Process):
|
||||
async def start_process(project: Project, process: Process):
|
||||
command = shlex.split(process.command)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
return await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
env={**os.environ, **project.get_environment(), "CATFISH_IDENT": process.ident},
|
||||
env={
|
||||
**os.environ,
|
||||
**project.get_environment(),
|
||||
"CATFISH_IDENT": process.ident,
|
||||
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
|
||||
},
|
||||
cwd=project.root
|
||||
)
|
||||
|
||||
|
||||
async def run_process_command(project: Project, process: Process):
|
||||
proc = await start_process(project, process)
|
||||
asyncio.ensure_future(publish_stdout_for(proc, process))
|
||||
return proc
|
||||
|
||||
|
@ -96,7 +105,7 @@ async def client_connected(reader, writer):
|
|||
|
||||
|
||||
async def start_server():
|
||||
atexit.register(terminate_processes)
|
||||
atexit.register(terminate_subprocesses)
|
||||
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
|
||||
click.echo("Started server")
|
||||
await server.serve_forever()
|
||||
|
|
|
@ -67,6 +67,7 @@ class BaseWorkerTestCase(BaseTestCase):
|
|||
time.sleep(0.1)
|
||||
|
||||
def tearDown(self):
|
||||
terminate_subprocesses(psutil.Process(self.worker_process.pid))
|
||||
self.worker_process.terminate()
|
||||
self.worker_process.wait()
|
||||
super().tearDown()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import psutil
|
||||
|
||||
from catfish.project import Project
|
||||
from catfish.utils.processes import is_process_running
|
||||
from catfish.utils.processes import find_running_process_for, is_process_running
|
||||
from catfish.worker import BASE_SOCKET_DIR
|
||||
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
|
||||
from tests import BaseWorkerTestCase
|
||||
|
@ -25,6 +25,7 @@ class ProcessWorkerTestCase(BaseWorkerTestCase):
|
|||
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||
)
|
||||
self.assertTrue(is_process_running(response["pid"]))
|
||||
self.assertEqual(find_running_process_for(self.process).pid, response["pid"])
|
||||
|
||||
def test_additional_environment(self):
|
||||
response = send_to_server(
|
||||
|
|
Reference in a new issue