Change payload type to enum
This commit is contained in:
parent
f5d27f558e
commit
9b937d25c3
2 changed files with 23 additions and 14 deletions
|
@ -4,6 +4,7 @@ import os
|
|||
import shlex
|
||||
import socket
|
||||
import subprocess
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
@ -22,10 +23,17 @@ from catfish.utils.sockets import (
|
|||
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
||||
|
||||
|
||||
def send_to_server(payload):
|
||||
class PayloadType(Enum):
|
||||
PING = auto()
|
||||
PROCESS = auto()
|
||||
|
||||
|
||||
def send_to_server(type: PayloadType, payload):
|
||||
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
|
||||
sock.connect(str(WORKER_SERVER_SOCKET))
|
||||
sock.sendall(ujson.dumps(payload).encode() + NEW_LINE)
|
||||
sock.sendall(
|
||||
ujson.dumps({"type": type.value, "payload": payload}).encode() + NEW_LINE
|
||||
)
|
||||
return read_all_from_socket(sock)
|
||||
|
||||
|
||||
|
@ -59,9 +67,14 @@ async def publish_stdout(process):
|
|||
await aio.remove_file(socket_path)
|
||||
|
||||
|
||||
async def parse_payload(payload):
|
||||
data = ujson.loads(payload)
|
||||
return PayloadType(data["type"]), data["payload"]
|
||||
|
||||
|
||||
async def client_connected(reader, writer):
|
||||
data = ujson.loads(await reader.readline())
|
||||
if data["type"] == "process":
|
||||
payload_type, data = await parse_payload(await reader.readline())
|
||||
if payload_type == PayloadType.PROCESS:
|
||||
command = shlex.split(data["command"])
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
|
@ -71,7 +84,7 @@ async def client_connected(reader, writer):
|
|||
)
|
||||
asyncio.ensure_future(publish_stdout(proc))
|
||||
write_data(writer, {"pid": proc.pid})
|
||||
elif data["type"] == "ping":
|
||||
elif payload_type == PayloadType.PING:
|
||||
write_data(writer, {"ping": "pong"})
|
||||
else:
|
||||
write_data(writer, {"error": "Invalid command"})
|
||||
|
|
|
@ -1,31 +1,27 @@
|
|||
from catfish.utils.processes import is_process_running
|
||||
from catfish.utils.sockets import stdout_socket_for_pid
|
||||
from catfish.worker.server import read_from_stdout_socket, send_to_server
|
||||
from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server
|
||||
from tests import BaseWorkerTestCase
|
||||
|
||||
|
||||
class WorkerServerTestCase(BaseWorkerTestCase):
|
||||
def test_server_creates_process(self):
|
||||
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)})
|
||||
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
|
||||
self.assertTrue(is_process_running(response["pid"]))
|
||||
|
||||
def test_unknown_command_type(self):
|
||||
response = send_to_server({"type": "nothing"})
|
||||
self.assertEqual(response, {"error": "Invalid command"})
|
||||
|
||||
def test_ping(self):
|
||||
response = send_to_server({"type": "ping"})
|
||||
response = send_to_server(PayloadType.PING, {})
|
||||
self.assertEqual(response, {"ping": "pong"})
|
||||
|
||||
|
||||
class ProcessLogsTestCase(BaseWorkerTestCase):
|
||||
def test_creates_socket(self):
|
||||
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)})
|
||||
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
|
||||
stdout_socket = stdout_socket_for_pid(response["pid"])
|
||||
self.assertTrue(stdout_socket.exists())
|
||||
|
||||
def test_gets_logs(self):
|
||||
response = send_to_server({"type": "process", "command": str(self.DUMMY_EXE)})
|
||||
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
|
||||
stdout_socket = stdout_socket_for_pid(response["pid"])
|
||||
stdout_iter = read_from_stdout_socket(stdout_socket)
|
||||
for i in range(3):
|
||||
|
|
Reference in a new issue