This repository has been archived on 2023-03-26. You can view files and clone it, but cannot push or open issues or pull requests.
catfish/catfish/worker/server.py

103 lines
3.0 KiB
Python

import asyncio
import atexit
import os
import shlex
import socket
import subprocess
from enum import Enum, auto
import click
import zmq
import ujson
from catfish.project import Process, Project
from catfish.utils import aio
from catfish.utils.processes import terminate_processes
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
# Unix-domain socket path shared by both ends of the worker control channel:
# start_server listens on it and send_to_server connects to it.
WORKER_SERVER_SOCKET = BASE_SOCKET_DIR.joinpath("catfish.sock")
class PayloadType(Enum):
    """Kinds of request understood by the worker server.

    The integer ``value`` of each member is what actually travels over the
    socket (send_to_server writes it, parse_payload maps it back), so the
    numbers are written out explicitly rather than left to ``auto()``.
    """

    PING = 1
    PROCESS = 2
def send_to_server(type: PayloadType, payload):
    """Send a single request to the worker server and return its reply.

    Opens a stream connection to WORKER_SERVER_SOCKET and writes one
    ujson-encoded object ``{"type": <enum value>, "payload": payload}``
    followed by NEW_LINE. Returns whatever read_all_from_socket reads back
    from the same connection.
    """
    with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as conn:
        conn.connect(str(WORKER_SERVER_SOCKET))
        request = {"type": type.value, "payload": payload}
        line = ujson.dumps(request).encode() + NEW_LINE
        conn.sendall(line)
        return read_all_from_socket(conn)
def read_logs_for_process(process: Process):
    """Yield log lines published for *process*, indefinitely.

    Connects a zmq SUB socket, subscribed to every topic, to the IPC endpoint
    ``BASE_SOCKET_DIR / process.logs_socket``. Yields each received message
    with surrounding whitespace removed. The generator never finishes on its
    own. The socket and its context are released when the generator is
    closed (``gen.close()`` or garbage collection) or when ``recv_string``
    raises (e.g. KeyboardInterrupt).
    """
    ctx = zmq.Context()
    sock = ctx.socket(zmq.SUB)
    try:
        socket_path = str(BASE_SOCKET_DIR.joinpath(process.logs_socket))
        sock.connect("ipc://" + socket_path)
        sock.setsockopt_string(zmq.SUBSCRIBE, "")
        while True:
            # NOTE: strip() also drops leading indentation of the log line.
            yield sock.recv_string().strip()
    finally:
        # Previously the SUB socket and context were never closed, so every
        # abandoned reader leaked them.
        sock.close()
        ctx.destroy()
def write_data(writer, data):
    """Serialise *data* with ujson and queue the encoded bytes on *writer*.

    No trailing newline is appended, and the writer is not drained here;
    callers are expected to ``await writer.drain()`` themselves.
    """
    encoded = ujson.dumps(data).encode()
    writer.write(encoded)
async def publish_stdout_for(process, ctf_process: Process):
    """Publish every stdout line of *process* on the log socket of *ctf_process*.

    *process* is the asyncio subprocess started by run_process_command (stdout
    is a pipe with stderr merged into it). *ctf_process* supplies
    ``logs_socket``, the IPC endpoint name under BASE_SOCKET_DIR that
    read_logs_for_process subscribes to.

    Publishes until stdout reaches EOF. Then it kills the child if asyncio has
    not yet recorded a return code, closes the zmq socket and context, and
    removes the socket file. Cleanup also runs when reading or publishing
    raises or the task is cancelled.
    """
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUB)
    socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
    sock.bind("ipc://" + socket_path)
    try:
        while True:
            output = await process.stdout.readline()
            if not output:
                break
            # errors="replace": a non-UTF-8 byte used to raise here and end the
            # publisher, after which nothing would drain the child's pipe.
            sock.send_string(output.decode(errors="replace"))
    finally:
        # EOF usually means the child has exited. Once asyncio has finished the
        # transport, kill() raises ProcessLookupError, which previously
        # skipped all of the cleanup below.
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass
        sock.close()
        ctx.destroy()
        await aio.remove_file(socket_path)
# Strong references to the fire-and-forget publish_stdout_for tasks. The event
# loop keeps only weak references to tasks, so an unreferenced task can be
# garbage-collected before it finishes. Each task removes itself when done.
_background_tasks = set()


async def run_process_command(project: Project, process: Process):
    """Spawn *process* for *project* and start streaming its output.

    The command string is tokenised with shlex.split and executed directly
    (no shell). It runs with cwd ``project.root`` and an environment built
    from os.environ, overlaid with ``project.get_environment()``, then
    ``CATFISH_IDENT=process.ident`` (later entries win). stderr is merged
    into stdout, which a background publish_stdout_for task streams to the
    process's log socket.

    Returns the asyncio subprocess object without waiting for it to exit.
    """
    command = shlex.split(process.command)
    proc = await asyncio.create_subprocess_exec(
        *command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env={**os.environ, **project.get_environment(), "CATFISH_IDENT": process.ident},
        cwd=project.root,
    )
    task = asyncio.ensure_future(publish_stdout_for(proc, process))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return proc
async def parse_payload(payload):
    """Decode one request line into a ``(PayloadType, payload)`` pair.

    Propagates whatever ujson.loads raises for malformed input, an indexing
    error (e.g. KeyError) when "type" or "payload" is absent, and ValueError
    when "type" is not a PayloadType value.
    """
    decoded = ujson.loads(payload)
    kind = PayloadType(decoded["type"])
    return kind, decoded["payload"]
async def client_connected(reader, writer):
    """Handle one request on the worker control socket, then close it.

    Reads exactly one line from *reader*, decodes it with parse_payload, and
    writes a single ujson-encoded reply via write_data:

    * PROCESS -> spawns the requested project process, replies ``{"pid": ...}``
    * PING    -> replies ``{"ping": "pong"}``
    * anything unparseable (empty line, bad JSON, missing keys, unknown
      type) -> replies ``{"error": "Invalid command"}``

    The writer is closed even if handling the request raises.
    """
    try:
        try:
            payload_type, data = await parse_payload(await reader.readline())
        except (ValueError, KeyError, TypeError):
            # PayloadType(...) raises ValueError for unknown types, so without
            # this the "Invalid command" branch below was unreachable and bad
            # requests crashed the handler instead of getting a reply.
            payload_type, data = None, None

        if payload_type == PayloadType.PROCESS:
            project = Project(data["path"])
            process = project.get_process(data["process"])
            proc = await run_process_command(project, process)
            write_data(writer, {"pid": proc.pid})
        elif payload_type == PayloadType.PING:
            write_data(writer, {"ping": "pong"})
        else:
            write_data(writer, {"error": "Invalid command"})
        await writer.drain()
    finally:
        # Previously skipped on any exception, leaking the client connection.
        writer.close()
async def start_server():
    """Run the worker control server on WORKER_SERVER_SOCKET until cancelled.

    Registers terminate_processes as an interpreter-exit hook, and starts an
    asyncio Unix-socket server that passes each connection to
    client_connected. Then it prints "Started server" and serves forever.
    """
    atexit.register(terminate_processes)
    unix_server = await asyncio.start_unix_server(
        client_connected, path=WORKER_SERVER_SOCKET
    )
    click.echo("Started server")
    await unix_server.serve_forever()