Use process ident for logs socket filename

This commit is contained in:
Jake Howard 2018-12-18 21:04:29 +00:00
parent 9b937d25c3
commit fbdef2ecbf
Signed by: jake
GPG key ID: 57AFB45680EDD477
4 changed files with 51 additions and 32 deletions

View file

@ -63,3 +63,7 @@ class Process:
@property @property
def ident(self): def ident(self):
return self.project.name + ":" + self.name return self.project.name + ":" + self.name
@property
def logs_socket(self):
return f"{self.project.name}-{self.name}.sock"

View file

@ -40,7 +40,3 @@ def delete_base_socket_dir():
shutil.rmtree(BASE_SOCKET_DIR) shutil.rmtree(BASE_SOCKET_DIR)
except FileNotFoundError: except FileNotFoundError:
pass pass
def stdout_socket_for_pid(pid: int) -> Path:
return BASE_SOCKET_DIR.joinpath("{}.stdout.sock".format(pid))

View file

@ -5,20 +5,15 @@ import shlex
import socket import socket
import subprocess import subprocess
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path
import click import click
import zmq import zmq
import ujson import ujson
from catfish.project import Process, Project
from catfish.utils import aio from catfish.utils import aio
from catfish.utils.processes import terminate_processes from catfish.utils.processes import terminate_processes
from catfish.utils.sockets import ( from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
BASE_SOCKET_DIR,
NEW_LINE,
read_all_from_socket,
stdout_socket_for_pid,
)
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock") WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
@ -37,9 +32,10 @@ def send_to_server(type: PayloadType, payload):
return read_all_from_socket(sock) return read_all_from_socket(sock)
def read_from_stdout_socket(socket_path: Path): def read_logs_for_process(process: Process):
ctx = zmq.Context() ctx = zmq.Context()
sock = ctx.socket(zmq.SUB) sock = ctx.socket(zmq.SUB)
socket_path = str(BASE_SOCKET_DIR.joinpath(process.logs_socket))
sock.connect("ipc://" + str(socket_path)) sock.connect("ipc://" + str(socket_path))
sock.setsockopt_string(zmq.SUBSCRIBE, "") sock.setsockopt_string(zmq.SUBSCRIBE, "")
while True: while True:
@ -50,12 +46,11 @@ def write_data(writer, data):
writer.write(ujson.dumps(data).encode()) writer.write(ujson.dumps(data).encode())
async def publish_stdout(process): async def publish_stdout_for(process, ctf_process: Process):
ctx = zmq.Context() ctx = zmq.Context()
sock = ctx.socket(zmq.PUB) sock = ctx.socket(zmq.PUB)
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
socket_path = stdout_socket_for_pid(process.pid) sock.bind("ipc://" + socket_path)
sock.bind("ipc://" + str(socket_path))
while True: while True:
output = await process.stdout.readline() output = await process.stdout.readline()
if not output: if not output:
@ -67,6 +62,19 @@ async def publish_stdout(process):
await aio.remove_file(socket_path) await aio.remove_file(socket_path)
async def run_process_command(project: Project, process: Process):
command = shlex.split(process.command)
proc = await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONUNBUFFERED": "1"},
cwd=project.root
)
asyncio.ensure_future(publish_stdout_for(proc, process))
return proc
async def parse_payload(payload): async def parse_payload(payload):
data = ujson.loads(payload) data = ujson.loads(payload)
return PayloadType(data["type"]), data["payload"] return PayloadType(data["type"]), data["payload"]
@ -75,14 +83,9 @@ async def parse_payload(payload):
async def client_connected(reader, writer): async def client_connected(reader, writer):
payload_type, data = await parse_payload(await reader.readline()) payload_type, data = await parse_payload(await reader.readline())
if payload_type == PayloadType.PROCESS: if payload_type == PayloadType.PROCESS:
command = shlex.split(data["command"]) project = Project(data["path"])
proc = await asyncio.create_subprocess_exec( process = project.get_process(data["process"])
*command, proc = await run_process_command(project, process)
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONUNBUFFERED": "1"}
)
asyncio.ensure_future(publish_stdout(proc))
write_data(writer, {"pid": proc.pid}) write_data(writer, {"pid": proc.pid})
elif payload_type == PayloadType.PING: elif payload_type == PayloadType.PING:
write_data(writer, {"ping": "pong"}) write_data(writer, {"ping": "pong"})

View file

@ -1,12 +1,18 @@
from catfish.project import Project
from catfish.utils.processes import is_process_running from catfish.utils.processes import is_process_running
from catfish.utils.sockets import stdout_socket_for_pid from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase from tests import BaseWorkerTestCase
class WorkerServerTestCase(BaseWorkerTestCase): class WorkerServerTestCase(BaseWorkerTestCase):
def test_server_creates_process(self): def test_server_creates_process(self):
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) project = Project(self.EXAMPLE_DIR)
process = project.get_process("bg")
response = send_to_server(
PayloadType.PROCESS,
{"path": str(project.root), "process": str(process.name)},
)
self.assertTrue(is_process_running(response["pid"])) self.assertTrue(is_process_running(response["pid"]))
def test_ping(self): def test_ping(self):
@ -15,14 +21,24 @@ class WorkerServerTestCase(BaseWorkerTestCase):
class ProcessLogsTestCase(BaseWorkerTestCase): class ProcessLogsTestCase(BaseWorkerTestCase):
def setUp(self):
super().setUp()
self.project = Project(self.EXAMPLE_DIR)
self.process = self.project.get_process("bg")
def test_creates_socket(self): def test_creates_socket(self):
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) send_to_server(
stdout_socket = stdout_socket_for_pid(response["pid"]) PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
stdout_socket = BASE_SOCKET_DIR.joinpath(self.process.logs_socket)
self.assertTrue(stdout_socket.exists()) self.assertTrue(stdout_socket.exists())
def test_gets_logs(self): def test_gets_logs(self):
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)}) send_to_server(
stdout_socket = stdout_socket_for_pid(response["pid"]) PayloadType.PROCESS,
stdout_iter = read_from_stdout_socket(stdout_socket) {"path": str(self.project.root), "process": str(self.process.name)},
)
stdout_iter = read_logs_for_process(self.process)
for i in range(3): for i in range(3):
self.assertEqual(next(stdout_iter), "Round {}".format(i)) self.assertEqual(next(stdout_iter), "Round {}".format(i))