Use process ident for logs socket filename

This commit is contained in:
Jake Howard 2018-12-18 21:04:29 +00:00
parent 9b937d25c3
commit fbdef2ecbf
Signed by: jake
GPG key ID: 57AFB45680EDD477
4 changed files with 51 additions and 32 deletions

View file

@ -63,3 +63,7 @@ class Process:
@property
def ident(self):
return self.project.name + ":" + self.name
@property
def logs_socket(self):
return f"{self.project.name}-{self.name}.sock"

View file

@ -40,7 +40,3 @@ def delete_base_socket_dir():
shutil.rmtree(BASE_SOCKET_DIR)
except FileNotFoundError:
pass
def stdout_socket_for_pid(pid: int) -> Path:
return BASE_SOCKET_DIR.joinpath("{}.stdout.sock".format(pid))

View file

@ -5,20 +5,15 @@ import shlex
import socket
import subprocess
from enum import Enum, auto
from pathlib import Path
import click
import zmq
import ujson
from catfish.project import Process, Project
from catfish.utils import aio
from catfish.utils.processes import terminate_processes
from catfish.utils.sockets import (
BASE_SOCKET_DIR,
NEW_LINE,
read_all_from_socket,
stdout_socket_for_pid,
)
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
@ -37,9 +32,10 @@ def send_to_server(type: PayloadType, payload):
return read_all_from_socket(sock)
def read_from_stdout_socket(socket_path: Path):
def read_logs_for_process(process: Process):
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
socket_path = str(BASE_SOCKET_DIR.joinpath(process.logs_socket))
sock.connect("ipc://" + str(socket_path))
sock.setsockopt_string(zmq.SUBSCRIBE, "")
while True:
@ -50,12 +46,11 @@ def write_data(writer, data):
writer.write(ujson.dumps(data).encode())
async def publish_stdout(process):
async def publish_stdout_for(process, ctf_process: Process):
ctx = zmq.Context()
sock = ctx.socket(zmq.PUB)
socket_path = stdout_socket_for_pid(process.pid)
sock.bind("ipc://" + str(socket_path))
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
sock.bind("ipc://" + socket_path)
while True:
output = await process.stdout.readline()
if not output:
@ -67,6 +62,19 @@ async def publish_stdout(process):
await aio.remove_file(socket_path)
async def run_process_command(project: Project, process: Process):
command = shlex.split(process.command)
proc = await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONUNBUFFERED": "1"},
cwd=project.root
)
asyncio.ensure_future(publish_stdout_for(proc, process))
return proc
async def parse_payload(payload):
data = ujson.loads(payload)
return PayloadType(data["type"]), data["payload"]
@ -75,14 +83,9 @@ async def parse_payload(payload):
async def client_connected(reader, writer):
payload_type, data = await parse_payload(await reader.readline())
if payload_type == PayloadType.PROCESS:
command = shlex.split(data["command"])
proc = await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONUNBUFFERED": "1"}
)
asyncio.ensure_future(publish_stdout(proc))
project = Project(data["path"])
process = project.get_process(data["process"])
proc = await run_process_command(project, process)
write_data(writer, {"pid": proc.pid})
elif payload_type == PayloadType.PING:
write_data(writer, {"ping": "pong"})

View file

@ -1,12 +1,18 @@
from catfish.project import Project
from catfish.utils.processes import is_process_running
from catfish.utils.sockets import stdout_socket_for_pid
from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server
from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase
class WorkerServerTestCase(BaseWorkerTestCase):
def test_server_creates_process(self):
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
project = Project(self.EXAMPLE_DIR)
process = project.get_process("bg")
response = send_to_server(
PayloadType.PROCESS,
{"path": str(project.root), "process": str(process.name)},
)
self.assertTrue(is_process_running(response["pid"]))
def test_ping(self):
@ -15,14 +21,24 @@ class WorkerServerTestCase(BaseWorkerTestCase):
class ProcessLogsTestCase(BaseWorkerTestCase):
def setUp(self):
super().setUp()
self.project = Project(self.EXAMPLE_DIR)
self.process = self.project.get_process("bg")
def test_creates_socket(self):
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
stdout_socket = stdout_socket_for_pid(response["pid"])
send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
stdout_socket = BASE_SOCKET_DIR.joinpath(self.process.logs_socket)
self.assertTrue(stdout_socket.exists())
def test_gets_logs(self):
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
stdout_socket = stdout_socket_for_pid(response["pid"])
stdout_iter = read_from_stdout_socket(stdout_socket)
send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
stdout_iter = read_logs_for_process(self.process)
for i in range(3):
self.assertEqual(next(stdout_iter), "Round {}".format(i))