diff --git a/catfish/project/__init__.py b/catfish/project/__init__.py index 5a631e1..dee90a0 100644 --- a/catfish/project/__init__.py +++ b/catfish/project/__init__.py @@ -63,3 +63,7 @@ class Process: @property def ident(self): return self.project.name + ":" + self.name + + @property + def logs_socket(self): + return f"{self.project.name}-{self.name}.sock" diff --git a/catfish/utils/sockets.py b/catfish/utils/sockets.py index bd1b9eb..135e9fa 100644 --- a/catfish/utils/sockets.py +++ b/catfish/utils/sockets.py @@ -40,7 +40,3 @@ def delete_base_socket_dir(): shutil.rmtree(BASE_SOCKET_DIR) except FileNotFoundError: pass - - -def stdout_socket_for_pid(pid: int) -> Path: - return BASE_SOCKET_DIR.joinpath("{}.stdout.sock".format(pid)) diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 51d15d2..3a2b013 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -5,20 +5,15 @@ import shlex import socket import subprocess from enum import Enum, auto -from pathlib import Path import click import zmq import ujson +from catfish.project import Process, Project from catfish.utils import aio from catfish.utils.processes import terminate_processes -from catfish.utils.sockets import ( - BASE_SOCKET_DIR, - NEW_LINE, - read_all_from_socket, - stdout_socket_for_pid, -) +from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock") @@ -37,9 +32,10 @@ def send_to_server(type: PayloadType, payload): return read_all_from_socket(sock) -def read_from_stdout_socket(socket_path: Path): +def read_logs_for_process(process: Process): ctx = zmq.Context() sock = ctx.socket(zmq.SUB) + socket_path = str(BASE_SOCKET_DIR.joinpath(process.logs_socket)) sock.connect("ipc://" + str(socket_path)) sock.setsockopt_string(zmq.SUBSCRIBE, "") while True: @@ -50,12 +46,11 @@ def write_data(writer, data): writer.write(ujson.dumps(data).encode()) -async def publish_stdout(process): +async def publish_stdout_for(process, ctf_process: Process): ctx = zmq.Context() sock = ctx.socket(zmq.PUB) - - socket_path = stdout_socket_for_pid(process.pid) - sock.bind("ipc://" + str(socket_path)) + socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket)) + sock.bind("ipc://" + socket_path) while True: output = await process.stdout.readline() if not output: @@ -67,6 +62,19 @@ async def publish_stdout(process): await aio.remove_file(socket_path) +async def run_process_command(project: Project, process: Process): + command = shlex.split(process.command) + proc = await asyncio.create_subprocess_exec( + *command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + cwd=project.root + ) + asyncio.ensure_future(publish_stdout_for(proc, process)) + return proc + + async def parse_payload(payload): data = ujson.loads(payload) return PayloadType(data["type"]), data["payload"] @@ -75,14 +83,9 @@ async def parse_payload(payload): async def client_connected(reader, writer): payload_type, data = await parse_payload(await reader.readline()) if payload_type == PayloadType.PROCESS: - command = shlex.split(data["command"]) - proc = await asyncio.create_subprocess_exec( - *command, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - env={**os.environ, "PYTHONUNBUFFERED": "1"} - ) - asyncio.ensure_future(publish_stdout(proc)) + project = Project(data["path"]) + process = project.get_process(data["process"]) + proc = await run_process_command(project, process) write_data(writer, {"pid": proc.pid}) elif payload_type == PayloadType.PING: write_data(writer, {"ping": "pong"}) diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index c6f2982..ef10fb1 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -1,12 +1,18 @@ +from catfish.project import Project from catfish.utils.processes import is_process_running -from catfish.utils.sockets import stdout_socket_for_pid -from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server +from catfish.worker import BASE_SOCKET_DIR +from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server from tests import BaseWorkerTestCase class WorkerServerTestCase(BaseWorkerTestCase): def test_server_creates_process(self): - response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) + project = Project(self.EXAMPLE_DIR) + process = project.get_process("bg") + response = send_to_server( + PayloadType.PROCESS, + {"path": str(project.root), "process": str(process.name)}, + ) self.assertTrue(is_process_running(response["pid"])) def test_ping(self): @@ -15,14 +21,24 @@ class WorkerServerTestCase(BaseWorkerTestCase): class ProcessLogsTestCase(BaseWorkerTestCase): + def setUp(self): + super().setUp() + self.project = Project(self.EXAMPLE_DIR) + self.process = self.project.get_process("bg") + def test_creates_socket(self): - response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) - stdout_socket = stdout_socket_for_pid(response["pid"]) + send_to_server( + PayloadType.PROCESS, + {"path": str(self.project.root), "process": str(self.process.name)}, + ) + stdout_socket = BASE_SOCKET_DIR.joinpath(self.process.logs_socket) self.assertTrue(stdout_socket.exists()) def test_gets_logs(self): - response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) - stdout_socket = stdout_socket_for_pid(response["pid"]) - stdout_iter = read_from_stdout_socket(stdout_socket) + send_to_server( + PayloadType.PROCESS, + {"path": str(self.project.root), "process": str(self.process.name)}, + ) + stdout_iter = read_logs_for_process(self.process) for i in range(3): self.assertEqual(next(stdout_iter), "Round {}".format(i))