Pipe stdout to socket

This commit is contained in:
Jake Howard 2018-12-13 16:52:37 +00:00
parent ce29af5459
commit c722c12ba4
Signed by: jake
GPG key ID: 57AFB45680EDD477
7 changed files with 91 additions and 8 deletions

15
catfish/utils/aio.py Normal file
View file

@ -0,0 +1,15 @@
import asyncio
import os
from aiofiles import os as aios
remove_file = aios.wrap(os.remove)
path_exists = aios.wrap(os.path.exists)
async def await_file_exists(path: str):
while True:
exists = await path_exists(path)
if exists:
return
await asyncio.sleep(0.1)

View file

@ -1,3 +1,4 @@
import time
from typing import List
import psutil
@ -32,3 +33,8 @@ def is_process_running(pid: int) -> bool:
return True
except psutil.NoSuchProcess:
return False
def wait_for_process(pid: int):
while not is_process_running(pid):
time.sleep(0.1)

View file

@ -20,8 +20,6 @@ def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool:
def read_all_from_socket(socket):
data = b""
while NEW_LINE not in data:
if not socket_has_data(socket):
break
message = socket.recv(BUFFER_SIZE)
if message == b"":
break
@ -35,3 +33,7 @@ def create_base_socket_dir():
def delete_base_socket_dir():
shutil.rmtree(BASE_SOCKET_DIR)
def stdout_socket_for_pid(pid: int) -> str:
return os.path.join(BASE_SOCKET_DIR, "{}.stdout.sock".format(pid))

View file

@ -6,9 +6,16 @@ import subprocess
import time
import click
import zmq
import ujson
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
from catfish.utils import aio
from catfish.utils.sockets import (
BASE_SOCKET_DIR,
NEW_LINE,
read_all_from_socket,
stdout_socket_for_pid,
)
WORKER_SERVER_SOCKET = os.path.join(BASE_SOCKET_DIR, "catfish.sock")
@ -26,10 +33,36 @@ def send_to_server(payload):
return read_all_from_socket(sock)
def read_from_stdout_socket(socket_path):
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.connect("ipc://" + socket_path)
sock.setsockopt_string(zmq.SUBSCRIBE, "")
while True:
yield sock.recv_string().strip()
def write_data(writer, data):
writer.write(ujson.dumps(data).encode())
async def publish_stdout(process):
ctx = zmq.Context()
sock = ctx.socket(zmq.PUB)
socket_path = stdout_socket_for_pid(process.pid)
sock.bind("ipc://" + socket_path)
while True:
output = await process.stdout.readline()
if not output:
break
sock.send_string(output.decode())
process.kill()
sock.close()
ctx.destroy()
await aio.remove_file(socket_path)
async def client_connected(reader, writer):
data = ujson.loads(await reader.readline())
if data["type"] == "process":
@ -37,13 +70,17 @@ async def client_connected(reader, writer):
proc = await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={**os.environ, "PYTHONUNBUFFERED": "1"}
)
asyncio.ensure_future(publish_stdout(proc))
write_data(writer, {"pid": proc.pid})
elif data["type"] == "ping":
write_data(writer, {"ping": "pong"})
else:
write_data(writer, {"error": "Invalid command"})
await writer.drain()
writer.close()
async def start_server():

View file

@ -11,7 +11,15 @@ setup(
include_package_data=True,
zip_safe=False,
pathon_requires=">=3.6",
install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"],
install_requires=[
"click",
"daemonize",
"psutil",
"aiohttp",
"ujson",
"pyzmq",
"aiofiles",
],
entry_points="""
[console_scripts]
ctf=catfish.__main__:cli

View file

@ -1,10 +1,8 @@
#!/usr/bin/env python3
import sys
import time
from itertools import count
for num in count():
time.sleep(0.5)
sys.stdout.write("Round {}\n".format(num))
sys.stdout.flush()
print("Round {}".format(num)) # noqa: T001

View file

@ -1,5 +1,8 @@
import os
from catfish.utils.processes import is_process_running
from catfish.worker.server import send_to_server
from catfish.utils.sockets import stdout_socket_for_pid
from catfish.worker.server import read_from_stdout_socket, send_to_server
from tests import BaseWorkerTestCase
@ -15,3 +18,17 @@ class WorkerServerTestCase(BaseWorkerTestCase):
def test_ping(self):
response = send_to_server({"type": "ping"})
self.assertEqual(response, {"ping": "pong"})
class ProcessLogsTestCase(BaseWorkerTestCase):
def test_creates_socket(self):
response = send_to_server({"type": "process", "command": self.DUMMY_EXE})
stdout_socket = stdout_socket_for_pid(response["pid"])
self.assertTrue(os.path.exists(stdout_socket))
def test_gets_logs(self):
response = send_to_server({"type": "process", "command": self.DUMMY_EXE})
stdout_socket = stdout_socket_for_pid(response["pid"])
stdout_iter = read_from_stdout_socket(stdout_socket)
for i in range(3):
self.assertEqual(next(stdout_iter), "Round {}".format(i))