diff --git a/catfish/router/__init__.py b/catfish/router/__init__.py index a76a605..96c8a87 100644 --- a/catfish/router/__init__.py +++ b/catfish/router/__init__.py @@ -12,9 +12,9 @@ def get_server(): return web.Server(handle_request) -async def run_server(loop, port): - click.echo("Starting server...") +async def start_router(loop, port): + click.echo("Starting router...") aiohttp_server = get_server() aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port) - click.echo("Server listening on port {}".format(port)) + click.echo("Router listening on port {}".format(port)) await aio_server.serve_forever() diff --git a/catfish/utils/processes.py b/catfish/utils/processes.py index a1382a9..f4d044a 100644 --- a/catfish/utils/processes.py +++ b/catfish/utils/processes.py @@ -24,3 +24,11 @@ def get_root_process() -> psutil.Process: while proc.parent(): proc = proc.parent() return proc + + +def is_process_running(pid: int) -> bool: + try: + psutil.Process(pid) + return True + except psutil.NoSuchProcess: + return False diff --git a/catfish/utils/sockets.py b/catfish/utils/sockets.py new file mode 100644 index 0000000..1602b8f --- /dev/null +++ b/catfish/utils/sockets.py @@ -0,0 +1,24 @@ +import select + +import ujson + +BUFFER_SIZE = 4096 +DEFAULT_SOCKET_READ_TIMEOUT = 0.01 +NEW_LINE = b"\n" + + +def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool: + readable, _, _ = select.select([socket], [], [], timeout) + return socket in readable + + +def read_all_from_socket(socket): + data = b"" + while NEW_LINE not in data: + if not socket_has_data(socket): + break + message = socket.recv(BUFFER_SIZE) + if message == b"": + break + data += message + return ujson.loads(data) diff --git a/catfish/worker/__init__.py b/catfish/worker/__init__.py index a9f1e21..9cbeef5 100644 --- a/catfish/worker/__init__.py +++ b/catfish/worker/__init__.py @@ -5,9 +5,11 @@ import time import psutil -from catfish.router import run_server +from catfish.router import start_router from catfish.utils.processes import terminate_processes +from .server import start_server, WORKER_SERVER_SOCKET + PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid") @@ -31,6 +33,13 @@ def wait_for_worker(): time.sleep(0.1) +def wait_for_running_worker(): + while not all([ + os.path.exists(WORKER_SERVER_SOCKET) + ]): + time.sleep(0.1) + + def stop_worker(): if is_running(): terminate_processes([get_running_process()]) @@ -38,7 +47,7 @@ def stop_worker(): async def run_worker(port): loop = asyncio.get_running_loop() - await asyncio.gather(run_server(loop, port)) + await asyncio.gather(start_router(loop, port), start_server()) def run(port=8080): diff --git a/catfish/worker/server.py b/catfish/worker/server.py new file mode 100644 index 0000000..ebec9cb --- /dev/null +++ b/catfish/worker/server.py @@ -0,0 +1,50 @@ +import asyncio +import os +import shlex +import socket +import subprocess +import tempfile +import time +import click + +import ujson +from catfish.utils.sockets import NEW_LINE, read_all_from_socket + +WORKER_SERVER_SOCKET = os.path.join(tempfile.gettempdir(), "catfish.sock") + + +def send_to_server(payload): + with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: + while True: + try: + sock.connect(WORKER_SERVER_SOCKET) + break + except ConnectionRefusedError: + time.sleep(0.1) + + sock.sendall(ujson.dumps(payload).encode() + NEW_LINE) + return read_all_from_socket(sock) + + +def write_data(writer, data): + writer.write(ujson.dumps(data).encode()) + + +async def client_connected(reader, writer): + data = ujson.loads(await reader.readline()) + if data["type"] == "process": + command = shlex.split(data["command"]) + proc = await asyncio.create_subprocess_exec( + *command, + stdout=subprocess.PIPE, + env={**os.environ, "PYTHONUNBUFFERED": "1"} + ) + write_data(writer, {"pid": proc.pid}) + else: + write_data(writer, {"error": "Invalid command"}) + + +async def start_server(): + server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET) + click.echo("Started server") + await server.serve_forever() diff --git a/setup.py b/setup.py index 2101e6f..5af69a6 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( include_package_data=True, zip_safe=False, pathon_requires=">=3.6", - install_requires=["click", "daemonize", "psutil", "aiohttp"], + install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"], entry_points=""" [console_scripts] ctf=catfish.__main__:cli diff --git a/tests/__init__.py b/tests/__init__.py index ce37a8a..a0355e4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -30,6 +30,7 @@ class BaseWorkerTestCase(BaseTestCase): target=worker.run, args=(self.unused_port,), daemon=True ) self.router_process.start() + worker.wait_for_running_worker() def tearDown(self): self.router_process.terminate() diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py new file mode 100644 index 0000000..9a15cbc --- /dev/null +++ b/tests/test_worker/test_server.py @@ -0,0 +1,9 @@ +from tests import BaseWorkerTestCase +from catfish.worker.server import send_to_server +from catfish.utils.processes import is_process_running + + +class WorkerServerTestCase(BaseWorkerTestCase): + def test_server_creates_process(self): + response = send_to_server({"type": "process", "command": "yes"}) + self.assertTrue(is_process_running(response['pid']))