Better handle stopping router in tests
This commit is contained in:
parent
f6c8d5b19d
commit
1e190c7932
4 changed files with 30 additions and 8 deletions
|
@ -3,6 +3,8 @@ from typing import List
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
|
from catfish.project import Process
|
||||||
|
|
||||||
CURRENT_PROCESS = psutil.Process()
|
CURRENT_PROCESS = psutil.Process()
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +22,8 @@ def terminate_processes(procs: List[psutil.Process], timeout=3):
|
||||||
return alive
|
return alive
|
||||||
|
|
||||||
|
|
||||||
def terminate_subprocesses():
|
def terminate_subprocesses(parent: psutil.Process = CURRENT_PROCESS):
|
||||||
return terminate_processes(CURRENT_PROCESS.children(recursive=True))
|
terminate_processes(parent.children(recursive=True))
|
||||||
|
|
||||||
|
|
||||||
def get_root_process() -> psutil.Process:
|
def get_root_process() -> psutil.Process:
|
||||||
|
@ -47,3 +49,12 @@ def wait_for_process_start(pid: int):
|
||||||
def wait_for_process_terminate(pid: int):
|
def wait_for_process_terminate(pid: int):
|
||||||
while is_process_running(pid):
|
while is_process_running(pid):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def find_running_process_for(process: Process):
|
||||||
|
for proc in get_root_process().children(recursive=True):
|
||||||
|
try:
|
||||||
|
if proc.environ().get("CATFISH_IDENT") == process.ident:
|
||||||
|
return proc
|
||||||
|
except psutil.AccessDenied:
|
||||||
|
continue
|
||||||
|
|
|
@ -12,7 +12,7 @@ import zmq
|
||||||
import ujson
|
import ujson
|
||||||
from catfish.project import Process, Project
|
from catfish.project import Process, Project
|
||||||
from catfish.utils import aio
|
from catfish.utils import aio
|
||||||
from catfish.utils.processes import terminate_processes
|
from catfish.utils.processes import CURRENT_PROCESS, terminate_subprocesses
|
||||||
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
||||||
|
|
||||||
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
||||||
|
@ -62,15 +62,24 @@ async def publish_stdout_for(process, ctf_process: Process):
|
||||||
await aio.remove_file(socket_path)
|
await aio.remove_file(socket_path)
|
||||||
|
|
||||||
|
|
||||||
async def run_process_command(project: Project, process: Process):
|
async def start_process(project: Project, process: Process):
|
||||||
command = shlex.split(process.command)
|
command = shlex.split(process.command)
|
||||||
proc = await asyncio.create_subprocess_exec(
|
return await asyncio.create_subprocess_exec(
|
||||||
*command,
|
*command,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
env={**os.environ, **project.get_environment(), "CATFISH_IDENT": process.ident},
|
env={
|
||||||
|
**os.environ,
|
||||||
|
**project.get_environment(),
|
||||||
|
"CATFISH_IDENT": process.ident,
|
||||||
|
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
|
||||||
|
},
|
||||||
cwd=project.root
|
cwd=project.root
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_process_command(project: Project, process: Process):
|
||||||
|
proc = await start_process(project, process)
|
||||||
asyncio.ensure_future(publish_stdout_for(proc, process))
|
asyncio.ensure_future(publish_stdout_for(proc, process))
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
@ -96,7 +105,7 @@ async def client_connected(reader, writer):
|
||||||
|
|
||||||
|
|
||||||
async def start_server():
|
async def start_server():
|
||||||
atexit.register(terminate_processes)
|
atexit.register(terminate_subprocesses)
|
||||||
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
|
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
|
||||||
click.echo("Started server")
|
click.echo("Started server")
|
||||||
await server.serve_forever()
|
await server.serve_forever()
|
||||||
|
|
|
@ -67,6 +67,7 @@ class BaseWorkerTestCase(BaseTestCase):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
terminate_subprocesses(psutil.Process(self.worker_process.pid))
|
||||||
self.worker_process.terminate()
|
self.worker_process.terminate()
|
||||||
self.worker_process.wait()
|
self.worker_process.wait()
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
from catfish.project import Project
|
from catfish.project import Project
|
||||||
from catfish.utils.processes import is_process_running
|
from catfish.utils.processes import find_running_process_for, is_process_running
|
||||||
from catfish.worker import BASE_SOCKET_DIR
|
from catfish.worker import BASE_SOCKET_DIR
|
||||||
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
|
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
|
||||||
from tests import BaseWorkerTestCase
|
from tests import BaseWorkerTestCase
|
||||||
|
@ -25,6 +25,7 @@ class ProcessWorkerTestCase(BaseWorkerTestCase):
|
||||||
{"path": str(self.project.root), "process": str(self.process.name)},
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
)
|
)
|
||||||
self.assertTrue(is_process_running(response["pid"]))
|
self.assertTrue(is_process_running(response["pid"]))
|
||||||
|
self.assertEqual(find_running_process_for(self.process).pid, response["pid"])
|
||||||
|
|
||||||
def test_additional_environment(self):
|
def test_additional_environment(self):
|
||||||
response = send_to_server(
|
response = send_to_server(
|
||||||
|
|
Reference in a new issue