Use process ident for logs socket filename
This commit is contained in:
parent
9b937d25c3
commit
fbdef2ecbf
4 changed files with 51 additions and 32 deletions
|
@ -63,3 +63,7 @@ class Process:
|
||||||
@property
|
@property
|
||||||
def ident(self):
|
def ident(self):
|
||||||
return self.project.name + ":" + self.name
|
return self.project.name + ":" + self.name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logs_socket(self):
|
||||||
|
return f"{self.project.name}-{self.name}.sock"
|
||||||
|
|
|
@ -40,7 +40,3 @@ def delete_base_socket_dir():
|
||||||
shutil.rmtree(BASE_SOCKET_DIR)
|
shutil.rmtree(BASE_SOCKET_DIR)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def stdout_socket_for_pid(pid: int) -> Path:
|
|
||||||
return BASE_SOCKET_DIR.joinpath("{}.stdout.sock".format(pid))
|
|
||||||
|
|
|
@ -5,20 +5,15 @@ import shlex
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
|
from catfish.project import Process, Project
|
||||||
from catfish.utils import aio
|
from catfish.utils import aio
|
||||||
from catfish.utils.processes import terminate_processes
|
from catfish.utils.processes import terminate_processes
|
||||||
from catfish.utils.sockets import (
|
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
||||||
BASE_SOCKET_DIR,
|
|
||||||
NEW_LINE,
|
|
||||||
read_all_from_socket,
|
|
||||||
stdout_socket_for_pid,
|
|
||||||
)
|
|
||||||
|
|
||||||
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
||||||
|
|
||||||
|
@ -37,9 +32,10 @@ def send_to_server(type: PayloadType, payload):
|
||||||
return read_all_from_socket(sock)
|
return read_all_from_socket(sock)
|
||||||
|
|
||||||
|
|
||||||
def read_from_stdout_socket(socket_path: Path):
|
def read_logs_for_process(process: Process):
|
||||||
ctx = zmq.Context()
|
ctx = zmq.Context()
|
||||||
sock = ctx.socket(zmq.SUB)
|
sock = ctx.socket(zmq.SUB)
|
||||||
|
socket_path = str(BASE_SOCKET_DIR.joinpath(process.logs_socket))
|
||||||
sock.connect("ipc://" + str(socket_path))
|
sock.connect("ipc://" + str(socket_path))
|
||||||
sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
while True:
|
while True:
|
||||||
|
@ -50,12 +46,11 @@ def write_data(writer, data):
|
||||||
writer.write(ujson.dumps(data).encode())
|
writer.write(ujson.dumps(data).encode())
|
||||||
|
|
||||||
|
|
||||||
async def publish_stdout(process):
|
async def publish_stdout_for(process, ctf_process: Process):
|
||||||
ctx = zmq.Context()
|
ctx = zmq.Context()
|
||||||
sock = ctx.socket(zmq.PUB)
|
sock = ctx.socket(zmq.PUB)
|
||||||
|
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
|
||||||
socket_path = stdout_socket_for_pid(process.pid)
|
sock.bind("ipc://" + socket_path)
|
||||||
sock.bind("ipc://" + str(socket_path))
|
|
||||||
while True:
|
while True:
|
||||||
output = await process.stdout.readline()
|
output = await process.stdout.readline()
|
||||||
if not output:
|
if not output:
|
||||||
|
@ -67,6 +62,19 @@ async def publish_stdout(process):
|
||||||
await aio.remove_file(socket_path)
|
await aio.remove_file(socket_path)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_process_command(project: Project, process: Process):
|
||||||
|
command = shlex.split(process.command)
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*command,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
||||||
|
cwd=project.root
|
||||||
|
)
|
||||||
|
asyncio.ensure_future(publish_stdout_for(proc, process))
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
async def parse_payload(payload):
|
async def parse_payload(payload):
|
||||||
data = ujson.loads(payload)
|
data = ujson.loads(payload)
|
||||||
return PayloadType(data["type"]), data["payload"]
|
return PayloadType(data["type"]), data["payload"]
|
||||||
|
@ -75,14 +83,9 @@ async def parse_payload(payload):
|
||||||
async def client_connected(reader, writer):
|
async def client_connected(reader, writer):
|
||||||
payload_type, data = await parse_payload(await reader.readline())
|
payload_type, data = await parse_payload(await reader.readline())
|
||||||
if payload_type == PayloadType.PROCESS:
|
if payload_type == PayloadType.PROCESS:
|
||||||
command = shlex.split(data["command"])
|
project = Project(data["path"])
|
||||||
proc = await asyncio.create_subprocess_exec(
|
process = project.get_process(data["process"])
|
||||||
*command,
|
proc = await run_process_command(project, process)
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
env={**os.environ, "PYTHONUNBUFFERED": "1"}
|
|
||||||
)
|
|
||||||
asyncio.ensure_future(publish_stdout(proc))
|
|
||||||
write_data(writer, {"pid": proc.pid})
|
write_data(writer, {"pid": proc.pid})
|
||||||
elif payload_type == PayloadType.PING:
|
elif payload_type == PayloadType.PING:
|
||||||
write_data(writer, {"ping": "pong"})
|
write_data(writer, {"ping": "pong"})
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
|
from catfish.project import Project
|
||||||
from catfish.utils.processes import is_process_running
|
from catfish.utils.processes import is_process_running
|
||||||
from catfish.utils.sockets import stdout_socket_for_pid
|
from catfish.worker import BASE_SOCKET_DIR
|
||||||
from catfish.worker.server import PayloadType, read_from_stdout_socket, send_to_server
|
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
|
||||||
from tests import BaseWorkerTestCase
|
from tests import BaseWorkerTestCase
|
||||||
|
|
||||||
|
|
||||||
class WorkerServerTestCase(BaseWorkerTestCase):
|
class WorkerServerTestCase(BaseWorkerTestCase):
|
||||||
def test_server_creates_process(self):
|
def test_server_creates_process(self):
|
||||||
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
|
project = Project(self.EXAMPLE_DIR)
|
||||||
|
process = project.get_process("bg")
|
||||||
|
response = send_to_server(
|
||||||
|
PayloadType.PROCESS,
|
||||||
|
{"path": str(project.root), "process": str(process.name)},
|
||||||
|
)
|
||||||
self.assertTrue(is_process_running(response["pid"]))
|
self.assertTrue(is_process_running(response["pid"]))
|
||||||
|
|
||||||
def test_ping(self):
|
def test_ping(self):
|
||||||
|
@ -15,14 +21,24 @@ class WorkerServerTestCase(BaseWorkerTestCase):
|
||||||
|
|
||||||
|
|
||||||
class ProcessLogsTestCase(BaseWorkerTestCase):
|
class ProcessLogsTestCase(BaseWorkerTestCase):
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.project = Project(self.EXAMPLE_DIR)
|
||||||
|
self.process = self.project.get_process("bg")
|
||||||
|
|
||||||
def test_creates_socket(self):
|
def test_creates_socket(self):
|
||||||
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
|
send_to_server(
|
||||||
stdout_socket = stdout_socket_for_pid(response["pid"])
|
PayloadType.PROCESS,
|
||||||
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
|
)
|
||||||
|
stdout_socket = BASE_SOCKET_DIR.joinpath(self.process.logs_socket)
|
||||||
self.assertTrue(stdout_socket.exists())
|
self.assertTrue(stdout_socket.exists())
|
||||||
|
|
||||||
def test_gets_logs(self):
|
def test_gets_logs(self):
|
||||||
response = send_to_server(PayloadType.PROCESS, {"command": str(self.DUMMY_EXE)})
|
send_to_server(
|
||||||
stdout_socket = stdout_socket_for_pid(response["pid"])
|
PayloadType.PROCESS,
|
||||||
stdout_iter = read_from_stdout_socket(stdout_socket)
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
|
)
|
||||||
|
stdout_iter = read_logs_for_process(self.process)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
||||||
|
|
Reference in a new issue