diff --git a/catfish/utils/aio.py b/catfish/utils/aio.py new file mode 100644 index 0000000..d6c3a49 --- /dev/null +++ b/catfish/utils/aio.py @@ -0,0 +1,15 @@ +import asyncio +import os + +from aiofiles import os as aios + +remove_file = aios.wrap(os.remove) +path_exists = aios.wrap(os.path.exists) + + +async def await_file_exists(path: str): + while True: + exists = await path_exists(path) + if exists: + return + await asyncio.sleep(0.1) diff --git a/catfish/utils/processes.py b/catfish/utils/processes.py index f4d044a..595db9b 100644 --- a/catfish/utils/processes.py +++ b/catfish/utils/processes.py @@ -1,3 +1,4 @@ +import time from typing import List import psutil @@ -32,3 +33,8 @@ def is_process_running(pid: int) -> bool: return True except psutil.NoSuchProcess: return False + + +def wait_for_process(pid: int): + while not is_process_running(pid): + time.sleep(0.1) diff --git a/catfish/utils/sockets.py b/catfish/utils/sockets.py index e6e56a4..cc1e8e6 100644 --- a/catfish/utils/sockets.py +++ b/catfish/utils/sockets.py @@ -20,8 +20,6 @@ def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool: def read_all_from_socket(socket): data = b"" while NEW_LINE not in data: - if not socket_has_data(socket): - break message = socket.recv(BUFFER_SIZE) if message == b"": break @@ -35,3 +33,7 @@ def create_base_socket_dir(): def delete_base_socket_dir(): shutil.rmtree(BASE_SOCKET_DIR) + + +def stdout_socket_for_pid(pid: int) -> str: + return os.path.join(BASE_SOCKET_DIR, "{}.stdout.sock".format(pid)) diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 89d141a..37b3af0 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -6,9 +6,16 @@ import subprocess import time import click +import zmq import ujson -from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket +from catfish.utils import aio +from catfish.utils.sockets import ( + BASE_SOCKET_DIR, + NEW_LINE, + read_all_from_socket, + stdout_socket_for_pid, +) WORKER_SERVER_SOCKET = os.path.join(BASE_SOCKET_DIR, "catfish.sock") @@ -26,10 +33,36 @@ def send_to_server(payload): return read_all_from_socket(sock) +def read_from_stdout_socket(socket_path): + ctx = zmq.Context() + sock = ctx.socket(zmq.SUB) + sock.connect("ipc://" + socket_path) + sock.setsockopt_string(zmq.SUBSCRIBE, "") + while True: + yield sock.recv_string().strip() + + def write_data(writer, data): writer.write(ujson.dumps(data).encode()) +async def publish_stdout(process): + ctx = zmq.Context() + sock = ctx.socket(zmq.PUB) + + socket_path = stdout_socket_for_pid(process.pid) + sock.bind("ipc://" + socket_path) + while True: + output = await process.stdout.readline() + if not output: + break + sock.send_string(output.decode()) + process.kill() + sock.close() + ctx.destroy() + await aio.remove_file(socket_path) + + async def client_connected(reader, writer): data = ujson.loads(await reader.readline()) if data["type"] == "process": @@ -37,13 +70,17 @@ async def client_connected(reader, writer): proc = await asyncio.create_subprocess_exec( *command, stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, env={**os.environ, "PYTHONUNBUFFERED": "1"} ) + asyncio.ensure_future(publish_stdout(proc)) write_data(writer, {"pid": proc.pid}) elif data["type"] == "ping": write_data(writer, {"ping": "pong"}) else: write_data(writer, {"error": "Invalid command"}) + await writer.drain() + writer.close() async def start_server(): diff --git a/setup.py b/setup.py index 5af69a6..cc6f0f3 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,15 @@ setup( include_package_data=True, zip_safe=False, pathon_requires=">=3.6", - install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"], + install_requires=[ + "click", + "daemonize", + "psutil", + "aiohttp", + "ujson", + "pyzmq", + "aiofiles", + ], entry_points=""" [console_scripts] ctf=catfish.__main__:cli diff --git a/tests/dummy_program.py b/tests/dummy_program.py index ac44981..9dec57f 100755 --- a/tests/dummy_program.py +++ b/tests/dummy_program.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 -import sys import time from itertools import count for num in count(): time.sleep(0.5) - sys.stdout.write("Round {}\n".format(num)) - sys.stdout.flush() + print("Round {}".format(num)) # noqa: T001 diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index d7f27bf..bbf1000 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -1,5 +1,8 @@ +import os + from catfish.utils.processes import is_process_running -from catfish.worker.server import send_to_server +from catfish.utils.sockets import stdout_socket_for_pid +from catfish.worker.server import read_from_stdout_socket, send_to_server from tests import BaseWorkerTestCase @@ -15,3 +18,17 @@ class WorkerServerTestCase(BaseWorkerTestCase): def test_ping(self): response = send_to_server({"type": "ping"}) self.assertEqual(response, {"ping": "pong"}) + + +class ProcessLogsTestCase(BaseWorkerTestCase): + def test_creates_socket(self): + response = send_to_server({"type": "process", "command": self.DUMMY_EXE}) + stdout_socket = stdout_socket_for_pid(response["pid"]) + self.assertTrue(os.path.exists(stdout_socket)) + + def test_gets_logs(self): + response = send_to_server({"type": "process", "command": self.DUMMY_EXE}) + stdout_socket = stdout_socket_for_pid(response["pid"]) + stdout_iter = read_from_stdout_socket(stdout_socket) + for i in range(3): + self.assertEqual(next(stdout_iter), "Round {}".format(i))