103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
import asyncio
|
|
import atexit
|
|
import os
|
|
import shlex
|
|
import socket
|
|
import subprocess
|
|
from enum import Enum, auto
|
|
|
|
import click
|
|
import zmq
|
|
|
|
import ujson
|
|
from catfish.project import Process, Project
|
|
from catfish.utils import aio
|
|
from catfish.utils.processes import terminate_processes
|
|
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
|
|
|
# Path of the UNIX domain socket that the worker server listens on
# (start_server) and that clients connect to (send_to_server).
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
|
|
|
|
|
|
class PayloadType(Enum):
    """Kind of request a client sends to the worker server.

    The member's ``value`` is what is written to the wire by
    ``send_to_server`` and turned back into a member by ``parse_payload``.
    """

    # Liveness check; the server answers {"ping": "pong"}.
    PING = 1
    # Start one of a project's processes; the server answers {"pid": ...}.
    PROCESS = 2
|
|
|
|
|
|
def send_to_server(type: PayloadType, payload):
    """Send one request to the worker server and return its raw reply.

    Connects to ``WORKER_SERVER_SOCKET``, writes the request
    ``{"type": <type.value>, "payload": <payload>}`` as a single
    newline-terminated JSON line, and returns whatever
    ``read_all_from_socket`` reads back over the same connection.
    The connection is closed before returning.
    """
    # NOTE: ``type`` shadows the builtin; the name is kept because callers
    # may pass it by keyword.
    with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as conn:
        conn.connect(str(WORKER_SERVER_SOCKET))
        request = {"type": type.value, "payload": payload}
        # The server reads exactly one line per request, hence NEW_LINE.
        line = ujson.dumps(request).encode() + NEW_LINE
        conn.sendall(line)
        reply = read_all_from_socket(conn)
    return reply
|
|
|
|
|
|
def read_logs_for_process(process: Process):
    """Yield the log lines published for *process*, indefinitely.

    Subscribes a ZeroMQ SUB socket (all messages, no topic filter) to the IPC
    endpoint named by ``process.logs_socket`` under ``BASE_SOCKET_DIR``. This
    is the same endpoint ``publish_stdout_for`` binds its PUB socket to. Each
    received message is yielded with surrounding whitespace stripped.

    The generator never finishes by itself. When it is closed, explicitly or
    by garbage collection, the socket and its context are released.
    """
    socket_path = str(BASE_SOCKET_DIR.joinpath(process.logs_socket))
    ctx = zmq.Context()
    sock = ctx.socket(zmq.SUB)
    try:
        sock.connect("ipc://" + socket_path)
        # Empty prefix subscribes to every message.
        sock.setsockopt_string(zmq.SUBSCRIBE, "")
        while True:
            yield sock.recv_string().strip()
    finally:
        # Closes the SUB socket and terminates the context. A SUB socket has
        # nothing outbound to flush, so linger=0 keeps shutdown from blocking.
        ctx.destroy(linger=0)
|
|
|
|
|
|
def write_data(writer, data):
    """Serialize *data* to JSON and queue the encoded bytes on *writer*.

    Only ``writer.write`` is called; flushing (``drain``) and closing are
    left to the caller.
    """
    encoded = ujson.dumps(data).encode()
    writer.write(encoded)
|
|
|
|
|
|
async def publish_stdout_for(process, ctf_process: Process):
    """Relay *process*'s output, line by line, to a ZeroMQ PUB socket.

    Binds a PUB socket at the IPC endpoint named by ``ctf_process.logs_socket``
    under ``BASE_SOCKET_DIR`` and publishes every line read from
    ``process.stdout`` until EOF.

    Cleanup runs when the relay ends, including on error or cancellation:
    the child is killed if it is still running, the ZeroMQ socket and context
    are released, and the socket file is removed (only if it was bound here).

    *process* is the asyncio subprocess created by ``run_process_command``
    with ``stdout=PIPE``. *ctf_process* is the catfish ``Process`` definition;
    only its ``logs_socket`` is used.
    """
    socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUB)
    bound = False
    try:
        sock.bind("ipc://" + socket_path)
        bound = True
        while True:
            output = await process.stdout.readline()
            if not output:
                break
            # Child output is arbitrary bytes; one non-UTF-8 line must not
            # abort the whole relay.
            sock.send_string(output.decode(errors="replace"))
    finally:
        # stdout is closed (or we bailed out); make sure the child is gone.
        # Once asyncio has reaped the child, kill() raises ProcessLookupError,
        # which would otherwise skip all the cleanup below.
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass
        sock.close()
        ctx.destroy()
        if bound:
            await aio.remove_file(socket_path)
|
|
|
|
|
|
# Strong references to the log-relay tasks started by run_process_command.
# The event loop keeps only weak references to tasks, so an unreferenced
# task may be garbage-collected before it finishes.
_LOG_RELAY_TASKS = set()


async def run_process_command(project: Project, process: Process):
    """Start *process* of *project* and begin relaying its output.

    ``process.command`` is split shell-style (no shell is involved) and
    executed in ``project.root``. The environment is ``os.environ`` overlaid
    with ``project.get_environment()``, then with
    ``CATFISH_IDENT=process.ident``; later entries win.

    stdout and stderr are merged into one pipe, and a background task
    (``publish_stdout_for``) publishes that output.

    Returns the asyncio subprocess handle.
    """
    command = shlex.split(process.command)
    proc = await asyncio.create_subprocess_exec(
        *command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the published log stream
        env={**os.environ, **project.get_environment(), "CATFISH_IDENT": process.ident},
        cwd=project.root,
    )
    relay = asyncio.ensure_future(publish_stdout_for(proc, process))
    # Keep the task alive until it completes, then drop the reference.
    _LOG_RELAY_TASKS.add(relay)
    relay.add_done_callback(_LOG_RELAY_TASKS.discard)
    return proc
|
|
|
|
|
|
async def parse_payload(payload):
    """Decode one raw JSON request into ``(PayloadType, payload)``.

    Propagates ``ujson.loads`` parse errors, ``KeyError`` when ``"type"`` or
    ``"payload"`` is missing, and ``ValueError`` for an unknown type value.
    """
    message = ujson.loads(payload)
    kind = PayloadType(message["type"])
    body = message["payload"]
    return kind, body
|
|
|
|
|
|
async def client_connected(reader, writer):
    """Handle a single request on the worker server's UNIX socket.

    Reads one newline-terminated JSON request and writes one JSON reply:

    * ``PROCESS``: starts ``payload["process"]`` of the project at
      ``payload["path"]`` and replies ``{"pid": <pid>}``.
    * ``PING``: replies ``{"ping": "pong"}``.
    * Anything malformed or of unknown type: replies
      ``{"error": "Invalid command"}``.

    The connection is always closed afterwards, even if handling fails.
    """
    try:
        try:
            payload_type, data = await parse_payload(await reader.readline())
        except (ValueError, KeyError, TypeError):
            # Covers invalid JSON, an over-long line, a non-object request,
            # a missing key, or an unknown type value. Without this catch the
            # exception escaped the handler, the "Invalid command" branch was
            # unreachable, and the client never got a reply.
            payload_type, data = None, None

        if payload_type == PayloadType.PROCESS:
            project = Project(data["path"])
            process = project.get_process(data["process"])
            proc = await run_process_command(project, process)
            write_data(writer, {"pid": proc.pid})
        elif payload_type == PayloadType.PING:
            write_data(writer, {"ping": "pong"})
        else:
            write_data(writer, {"error": "Invalid command"})

        await writer.drain()
    finally:
        writer.close()
|
|
|
|
|
|
async def start_server():
    """Run the worker server on ``WORKER_SERVER_SOCKET`` until cancelled.

    Registers ``terminate_processes`` to run at interpreter exit. Each
    incoming connection is handled by ``client_connected``.
    """
    atexit.register(terminate_processes)
    unix_server = await asyncio.start_unix_server(
        client_connected, path=WORKER_SERVER_SOCKET
    )
    click.echo("Started server")
    await unix_server.serve_forever()
|