Pipe stdout to socket
This commit is contained in:
parent
ce29af5459
commit
c722c12ba4
7 changed files with 91 additions and 8 deletions
15
catfish/utils/aio.py
Normal file
15
catfish/utils/aio.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import asyncio
|
||||
import os
|
||||
|
||||
from aiofiles import os as aios
|
||||
|
||||
remove_file = aios.wrap(os.remove)
|
||||
path_exists = aios.wrap(os.path.exists)
|
||||
|
||||
|
||||
async def await_file_exists(path: str):
|
||||
while True:
|
||||
exists = await path_exists(path)
|
||||
if exists:
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
|
@ -1,3 +1,4 @@
|
|||
import time
|
||||
from typing import List
|
||||
|
||||
import psutil
|
||||
|
@ -32,3 +33,8 @@ def is_process_running(pid: int) -> bool:
|
|||
return True
|
||||
except psutil.NoSuchProcess:
|
||||
return False
|
||||
|
||||
|
||||
def wait_for_process(pid: int):
|
||||
while not is_process_running(pid):
|
||||
time.sleep(0.1)
|
||||
|
|
|
@ -20,8 +20,6 @@ def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool:
|
|||
def read_all_from_socket(socket):
|
||||
data = b""
|
||||
while NEW_LINE not in data:
|
||||
if not socket_has_data(socket):
|
||||
break
|
||||
message = socket.recv(BUFFER_SIZE)
|
||||
if message == b"":
|
||||
break
|
||||
|
@ -35,3 +33,7 @@ def create_base_socket_dir():
|
|||
|
||||
def delete_base_socket_dir():
|
||||
shutil.rmtree(BASE_SOCKET_DIR)
|
||||
|
||||
|
||||
def stdout_socket_for_pid(pid: int) -> str:
|
||||
return os.path.join(BASE_SOCKET_DIR, "{}.stdout.sock".format(pid))
|
||||
|
|
|
@ -6,9 +6,16 @@ import subprocess
|
|||
import time
|
||||
|
||||
import click
|
||||
import zmq
|
||||
|
||||
import ujson
|
||||
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
||||
from catfish.utils import aio
|
||||
from catfish.utils.sockets import (
|
||||
BASE_SOCKET_DIR,
|
||||
NEW_LINE,
|
||||
read_all_from_socket,
|
||||
stdout_socket_for_pid,
|
||||
)
|
||||
|
||||
WORKER_SERVER_SOCKET = os.path.join(BASE_SOCKET_DIR, "catfish.sock")
|
||||
|
||||
|
@ -26,10 +33,36 @@ def send_to_server(payload):
|
|||
return read_all_from_socket(sock)
|
||||
|
||||
|
||||
def read_from_stdout_socket(socket_path):
|
||||
ctx = zmq.Context()
|
||||
sock = ctx.socket(zmq.SUB)
|
||||
sock.connect("ipc://" + socket_path)
|
||||
sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
while True:
|
||||
yield sock.recv_string().strip()
|
||||
|
||||
|
||||
def write_data(writer, data):
|
||||
writer.write(ujson.dumps(data).encode())
|
||||
|
||||
|
||||
async def publish_stdout(process):
|
||||
ctx = zmq.Context()
|
||||
sock = ctx.socket(zmq.PUB)
|
||||
|
||||
socket_path = stdout_socket_for_pid(process.pid)
|
||||
sock.bind("ipc://" + socket_path)
|
||||
while True:
|
||||
output = await process.stdout.readline()
|
||||
if not output:
|
||||
break
|
||||
sock.send_string(output.decode())
|
||||
process.kill()
|
||||
sock.close()
|
||||
ctx.destroy()
|
||||
await aio.remove_file(socket_path)
|
||||
|
||||
|
||||
async def client_connected(reader, writer):
|
||||
data = ujson.loads(await reader.readline())
|
||||
if data["type"] == "process":
|
||||
|
@ -37,13 +70,17 @@ async def client_connected(reader, writer):
|
|||
proc = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
env={**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||
)
|
||||
asyncio.ensure_future(publish_stdout(proc))
|
||||
write_data(writer, {"pid": proc.pid})
|
||||
elif data["type"] == "ping":
|
||||
write_data(writer, {"ping": "pong"})
|
||||
else:
|
||||
write_data(writer, {"error": "Invalid command"})
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
|
||||
async def start_server():
|
||||
|
|
10
setup.py
10
setup.py
|
@ -11,7 +11,15 @@ setup(
|
|||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
pathon_requires=">=3.6",
|
||||
install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"],
|
||||
install_requires=[
|
||||
"click",
|
||||
"daemonize",
|
||||
"psutil",
|
||||
"aiohttp",
|
||||
"ujson",
|
||||
"pyzmq",
|
||||
"aiofiles",
|
||||
],
|
||||
entry_points="""
|
||||
[console_scripts]
|
||||
ctf=catfish.__main__:cli
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import time
|
||||
from itertools import count
|
||||
|
||||
for num in count():
|
||||
time.sleep(0.5)
|
||||
sys.stdout.write("Round {}\n".format(num))
|
||||
sys.stdout.flush()
|
||||
print("Round {}".format(num)) # noqa: T001
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import os
|
||||
|
||||
from catfish.utils.processes import is_process_running
|
||||
from catfish.worker.server import send_to_server
|
||||
from catfish.utils.sockets import stdout_socket_for_pid
|
||||
from catfish.worker.server import read_from_stdout_socket, send_to_server
|
||||
from tests import BaseWorkerTestCase
|
||||
|
||||
|
||||
|
@ -15,3 +18,17 @@ class WorkerServerTestCase(BaseWorkerTestCase):
|
|||
def test_ping(self):
|
||||
response = send_to_server({"type": "ping"})
|
||||
self.assertEqual(response, {"ping": "pong"})
|
||||
|
||||
|
||||
class ProcessLogsTestCase(BaseWorkerTestCase):
|
||||
def test_creates_socket(self):
|
||||
response = send_to_server({"type": "process", "command": self.DUMMY_EXE})
|
||||
stdout_socket = stdout_socket_for_pid(response["pid"])
|
||||
self.assertTrue(os.path.exists(stdout_socket))
|
||||
|
||||
def test_gets_logs(self):
|
||||
response = send_to_server({"type": "process", "command": self.DUMMY_EXE})
|
||||
stdout_socket = stdout_socket_for_pid(response["pid"])
|
||||
stdout_iter = read_from_stdout_socket(stdout_socket)
|
||||
for i in range(3):
|
||||
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
||||
|
|
Reference in a new issue