Change payload type to enum

This commit is contained in:
Jake Howard 2018-12-18 20:41:23 +00:00
parent f5d27f558e
commit 9b937d25c3
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 23 additions and 14 deletions

View file

@ -4,6 +4,7 @@ import os
import shlex import shlex
import socket import socket
import subprocess import subprocess
from enum import Enum, auto
from pathlib import Path from pathlib import Path
import click import click
@ -22,10 +23,17 @@ from catfish.utils.sockets import (
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock") WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
def send_to_server(payload): class PayloadType(Enum):
PING = auto()
PROCESS = auto()
def send_to_server(type: PayloadType, payload):
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
sock.connect(str(WORKER_SERVER_SOCKET)) sock.connect(str(WORKER_SERVER_SOCKET))
sock.sendall(ujson.dumps(payload).encode() + NEW_LINE) sock.sendall(
ujson.dumps({"type": type.value, "payload": payload}).encode() + NEW_LINE
)
return read_all_from_socket(sock) return read_all_from_socket(sock)
@ -59,9 +67,14 @@ async def publish_stdout(process):
await aio.remove_file(socket_path) await aio.remove_file(socket_path)
async def parse_payload(payload):
data = ujson.loads(payload)
return PayloadType(data["type"]), data["payload"]
async def client_connected(reader, writer): async def client_connected(reader, writer):
data = ujson.loads(await reader.readline()) payload_type, data = await parse_payload(await reader.readline())
if data["type"] == "process": if payload_type == PayloadType.PROCESS:
command = shlex.split(data["command"]) command = shlex.split(data["command"])
proc = await asyncio.create_subprocess_exec( proc = await asyncio.create_subprocess_exec(
*command, *command,
@ -71,7 +84,7 @@ async def client_connected(reader, writer):
) )
asyncio.ensure_future(publish_stdout(proc)) asyncio.ensure_future(publish_stdout(proc))
write_data(writer, {"pid": proc.pid}) write_data(writer, {"pid": proc.pid})
elif data["type"] == "ping": elif payload_type == PayloadType.PING:
write_data(writer, {"ping": "pong"}) write_data(writer, {"ping": "pong"})
else: else:
write_data(writer, {"error": "Invalid command"}) write_data(writer, {"error": "Invalid command"})

View file

@ -1,31 +1,27 @@
from catfish.utils.processes import is_process_running from catfish.utils.processes import is_process_running
from catfish.utils.sockets import stdout_socket_for_pid from catfish.utils.sockets import stdout_socket_for_pid
from catfish.worker.server import read_from_stdout_socket, send_to_server from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server
from tests import BaseWorkerTestCase from tests import BaseWorkerTestCase
class WorkerServerTestCase(BaseWorkerTestCase): class WorkerServerTestCase(BaseWorkerTestCase):
def test_server_creates_process(self): def test_server_creates_process(self):
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)}) response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
self.assertTrue(is_process_running(response["pid"])) self.assertTrue(is_process_running(response["pid"]))
def test_unknown_command_type(self):
response = send_to_server({"type": "nothing"})
self.assertEqual(response, {"error": "Invalid command"})
def test_ping(self): def test_ping(self):
response = send_to_server({"type": "ping"}) response = send_to_server(PayloadType.PING, {})
self.assertEqual(response, {"ping": "pong"}) self.assertEqual(response, {"ping": "pong"})
class ProcessLogsTestCase(BaseWorkerTestCase): class ProcessLogsTestCase(BaseWorkerTestCase):
def test_creates_socket(self): def test_creates_socket(self):
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)}) response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
stdout_socket = stdout_socket_for_pid(response["pid"]) stdout_socket = stdout_socket_for_pid(response["pid"])
self.assertTrue(stdout_socket.exists()) self.assertTrue(stdout_socket.exists())
def test_gets_logs(self): def test_gets_logs(self):
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)}) response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
stdout_socket = stdout_socket_for_pid(response["pid"]) stdout_socket = stdout_socket_for_pid(response["pid"])
stdout_iter = read_from_stdout_socket(stdout_socket) stdout_iter = read_from_stdout_socket(stdout_socket)
for i in range(3): for i in range(3):