diff --git a/catfish/worker/__init__.py b/catfish/worker/__init__.py index 9cbeef5..4b44796 100644 --- a/catfish/worker/__init__.py +++ b/catfish/worker/__init__.py @@ -8,7 +8,7 @@ import psutil from catfish.router import start_router from catfish.utils.processes import terminate_processes -from .server import start_server, WORKER_SERVER_SOCKET +from .server import WORKER_SERVER_SOCKET, start_server PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid") @@ -34,9 +34,7 @@ def wait_for_worker(): def wait_for_running_worker(): - while not all([ - os.path.exists(WORKER_SERVER_SOCKET) - ]): + while not all([os.path.exists(WORKER_SERVER_SOCKET)]): time.sleep(0.1) diff --git a/catfish/worker/server.py b/catfish/worker/server.py index ebec9cb..3ffbbda 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -5,6 +5,7 @@ import socket import subprocess import tempfile import time + import click import ujson @@ -40,6 +41,8 @@ async def client_connected(reader, writer): env={**os.environ, "PYTHONUNBUFFERED": "1"} ) write_data(writer, {"pid": proc.pid}) + elif data["type"] == "ping": + write_data(writer, {"ping": "pong"}) else: write_data(writer, {"error": "Invalid command"}) diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index 9a15cbc..79cee12 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -1,9 +1,17 @@ -from tests import BaseWorkerTestCase -from catfish.worker.server import send_to_server from catfish.utils.processes import is_process_running +from catfish.worker.server import send_to_server +from tests import BaseWorkerTestCase class WorkerServerTestCase(BaseWorkerTestCase): def test_server_creates_process(self): response = send_to_server({"type": "process", "command": "yes"}) - self.assertTrue(is_process_running(response['pid'])) + self.assertTrue(is_process_running(response["pid"])) + + def test_unknown_command_type(self): + response = send_to_server({"type": "nothing"}) + self.assertEqual(response, {"error": "Invalid command"}) + + def test_ping(self): + response = send_to_server({"type": "ping"}) + self.assertEqual(response, {"ping": "pong"})