From 9b937d25c3c3533162b4d972c9e05520017f90fa Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Tue, 18 Dec 2018 20:41:23 +0000 Subject: [PATCH] Change payload type to enum --- catfish/worker/server.py | 23 ++++++++++++++++++----- tests/test_worker/test_server.py | 14 +++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 131c900..51d15d2 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -4,6 +4,7 @@ import os import shlex import socket import subprocess +from enum import Enum, auto from pathlib import Path import click @@ -22,10 +23,17 @@ from catfish.utils.sockets import ( WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock") -def send_to_server(payload): +class PayloadType(Enum): + PING = auto() + PROCESS = auto() + + +def send_to_server(type: PayloadType, payload): with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: sock.connect(str(WORKER_SERVER_SOCKET)) - sock.sendall(ujson.dumps(payload).encode() + NEW_LINE) + sock.sendall( + ujson.dumps({"type": type.value, "payload": payload}).encode() + NEW_LINE + ) return read_all_from_socket(sock) @@ -59,9 +67,14 @@ async def publish_stdout(process): await aio.remove_file(socket_path) +async def parse_payload(payload): + data = ujson.loads(payload) + return PayloadType(data["type"]), data["payload"] + + async def client_connected(reader, writer): - data = ujson.loads(await reader.readline()) - if data["type"] == "process": + payload_type, data = await parse_payload(await reader.readline()) + if payload_type == PayloadType.PROCESS: command = shlex.split(data["command"]) proc = await asyncio.create_subprocess_exec( *command, @@ -71,7 +84,7 @@ async def client_connected(reader, writer): ) asyncio.ensure_future(publish_stdout(proc)) write_data(writer, {"pid": proc.pid}) - elif data["type"] == "ping": + elif payload_type == PayloadType.PING: write_data(writer, {"ping": "pong"}) else: write_data(writer, {"error": "Invalid command"}) diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index 92c387c..c6f2982 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -1,31 +1,27 @@ from catfish.utils.processes import is_process_running from catfish.utils.sockets import stdout_socket_for_pid -from catfish.worker.server import read_from_stdout_socket, send_to_server +from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server from tests import BaseWorkerTestCase class WorkerServerTestCase(BaseWorkerTestCase): def test_server_creates_process(self): - response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)}) + response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) self.assertTrue(is_process_running(response["pid"])) - def test_unknown_command_type(self): - response = send_to_server({"type": "nothing"}) - self.assertEqual(response, {"error": "Invalid command"}) - def test_ping(self): - response = send_to_server({"type": "ping"}) + response = send_to_server(PayloadType.PING, {}) self.assertEqual(response, {"ping": "pong"}) class ProcessLogsTestCase(BaseWorkerTestCase): def test_creates_socket(self): - response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)}) + response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) stdout_socket = stdout_socket_for_pid(response["pid"]) self.assertTrue(stdout_socket.exists()) def test_gets_logs(self): - response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)}) + response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) stdout_socket = stdout_socket_for_pid(response["pid"]) stdout_iter = read_from_stdout_socket(stdout_socket) for i in range(3):