diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 4a4d91a..d1fed1b 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -6,9 +6,11 @@ import signal import socket import subprocess from enum import Enum, auto +from typing import Optional import click import zmq +from aiohttp.test_utils import unused_port import ujson from catfish.project import Process, Project @@ -47,7 +49,9 @@ def write_data(writer, data): writer.write(ujson.dumps(data).encode()) -async def publish_stdout_for(process, ctf_process: Process, project: Project): +async def publish_stdout_for( + process, ctf_process: Process, project: Project, port: Optional[int] +): ctx = zmq.Context() sock = ctx.socket(zmq.PUB) socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket)) @@ -63,32 +67,40 @@ async def publish_stdout_for(process, ctf_process: Process, project: Project): exit_code = process.returncode if exit_code in [-signal.SIGHUP, 0, 1]: # If process gets SIGHUP, or exits cleanly / uncleanly, restart it - process = await start_process(project, ctf_process) + process = await start_process(project, ctf_process, port) finally: sock.close() ctx.destroy() await aio.remove_file(socket_path) -async def start_process(project: Project, process: Process): - command = shlex.split(process.command) +async def start_process(project: Project, process: Process, port: Optional[int]): + process_env = { + **os.environ, + **project.get_environment(), + "CATFISH_IDENT": process.ident, + "CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid), + } + command = process.command + if port is not None: + process_env["PORT"] = str(port) + command = command.replace("$PORT", str(port)) return await asyncio.create_subprocess_exec( - *command, + *shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - env={ - **os.environ, - **project.get_environment(), - "CATFISH_IDENT": process.ident, - "CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid), - }, + env=process_env, cwd=project.root ) async def run_process_command(project: Project, process: Process): - proc = await start_process(project, process) - asyncio.ensure_future(publish_stdout_for(proc, process, project)) + if "$PORT" in process.command: + port = unused_port() + else: + port = None + proc = await start_process(project, process, port) + asyncio.ensure_future(publish_stdout_for(proc, process, project, port)) return proc diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index 157c5c2..1902e3c 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -137,3 +137,41 @@ class ProcessLogsTestCase(BaseWorkerTestCase): stdout_iter = read_logs_for_process(self.process) for i in range(3): self.assertEqual(next(stdout_iter), "Round {}".format(i)) + + +class ProcessPortTestCase(BaseWorkerTestCase): + def setUp(self): + super().setUp() + self.project = Project(self.EXAMPLE_DIR) + self.process = self.project.get_process("web") + + def test_assigns_port(self): + response = send_to_server( + PayloadType.PROCESS, + {"path": str(self.project.root), "process": str(self.process.name)}, + ) + process = psutil.Process(response["pid"]) + self.assertIn("PORT", process.environ()) + self.assertEqual(self.process.port, process.environ()["PORT"]) + + def test_doesnt_assign_port(self): + response = send_to_server( + PayloadType.PROCESS, {"path": str(self.project.root), "process": "bg"} + ) + process = psutil.Process(response["pid"]) + self.assertNotIn("PORT", process.environ()) + self.assertIsNone(self.process.port) + + def test_keeps_port_on_restart(self): + response = send_to_server( + PayloadType.PROCESS, + {"path": str(self.project.root), "process": str(self.process.name)}, + ) + initial_pid = response["pid"] + initial_process = psutil.Process(initial_pid) + port = initial_process.environ()["PORT"] + psutil.Process(initial_pid).send_signal(signal.SIGHUP) + wait_for_process_terminate(initial_pid) + time.sleep(2) + new_process = get_running_process_for(self.process) + self.assertEqual(new_process.environ()["PORT"], port)