Change payload type to enum

This commit is contained in:
Jake Howard 2018-12-18 20:41:23 +00:00
parent f5d27f558e
commit 9b937d25c3
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 23 additions and 14 deletions

View file

@ -4,6 +4,7 @@ import os
import shlex
import socket
import subprocess
from enum import Enum, auto
from pathlib import Path
import click
@ -22,10 +23,17 @@ from catfish.utils.sockets import (
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
def send_to_server(payload):
class PayloadType(Enum):
PING = auto()
PROCESS = auto()
def send_to_server(type: PayloadType, payload):
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
sock.connect(str(WORKER_SERVER_SOCKET))
sock.sendall(ujson.dumps(payload).encode() + NEW_LINE)
sock.sendall(
ujson.dumps({"type": type.value, "payload": payload}).encode() + NEW_LINE
)
return read_all_from_socket(sock)
@ -59,9 +67,14 @@ async def publish_stdout(process):
await aio.remove_file(socket_path)
async def parse_payload(payload):
data = ujson.loads(payload)
return PayloadType(data["type"]), data["payload"]
async def client_connected(reader, writer):
data = ujson.loads(await reader.readline())
if data["type"] == "process":
payload_type, data = await parse_payload(await reader.readline())
if payload_type == PayloadType.PROCESS:
command = shlex.split(data["command"])
proc = await asyncio.create_subprocess_exec(
*command,
@ -71,7 +84,7 @@ async def client_connected(reader, writer):
)
asyncio.ensure_future(publish_stdout(proc))
write_data(writer, {"pid": proc.pid})
elif data["type"] == "ping":
elif payload_type == PayloadType.PING:
write_data(writer, {"ping": "pong"})
else:
write_data(writer, {"error": "Invalid command"})

View file

@ -1,31 +1,27 @@
from catfish.utils.processes import is_process_running
from catfish.utils.sockets import stdout_socket_for_pid
from catfish.worker.server import read_from_stdout_socket, send_to_server
from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server
from tests import BaseWorkerTestCase
class WorkerServerTestCase(BaseWorkerTestCase):
def test_server_creates_process(self):
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)})
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
self.assertTrue(is_process_running(response["pid"]))
def test_unknown_command_type(self):
response = send_to_server({"type": "nothing"})
self.assertEqual(response, {"error": "Invalid command"})
def test_ping(self):
response = send_to_server({"type": "ping"})
response = send_to_server(PayloadType.PING, {})
self.assertEqual(response, {"ping": "pong"})
class ProcessLogsTestCase(BaseWorkerTestCase):
def test_creates_socket(self):
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)})
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
stdout_socket = stdout_socket_for_pid(response["pid"])
self.assertTrue(stdout_socket.exists())
def test_gets_logs(self):
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)})
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
stdout_socket = stdout_socket_for_pid(response["pid"])
stdout_iter = read_from_stdout_socket(stdout_socket)
for i in range(3):