Pipe stdout to socket
This commit is contained in:
parent
ce29af5459
commit
c722c12ba4
7 changed files with 91 additions and 8 deletions
15
catfish/utils/aio.py
Normal file
15
catfish/utils/aio.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from aiofiles import os as aios
|
||||||
|
|
||||||
|
remove_file = aios.wrap(os.remove)
|
||||||
|
path_exists = aios.wrap(os.path.exists)
|
||||||
|
|
||||||
|
|
||||||
|
async def await_file_exists(path: str):
|
||||||
|
while True:
|
||||||
|
exists = await path_exists(path)
|
||||||
|
if exists:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.1)
|
|
@ -1,3 +1,4 @@
|
||||||
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
@ -32,3 +33,8 @@ def is_process_running(pid: int) -> bool:
|
||||||
return True
|
return True
|
||||||
except psutil.NoSuchProcess:
|
except psutil.NoSuchProcess:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_process(pid: int):
|
||||||
|
while not is_process_running(pid):
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
|
@ -20,8 +20,6 @@ def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool:
|
||||||
def read_all_from_socket(socket):
|
def read_all_from_socket(socket):
|
||||||
data = b""
|
data = b""
|
||||||
while NEW_LINE not in data:
|
while NEW_LINE not in data:
|
||||||
if not socket_has_data(socket):
|
|
||||||
break
|
|
||||||
message = socket.recv(BUFFER_SIZE)
|
message = socket.recv(BUFFER_SIZE)
|
||||||
if message == b"":
|
if message == b"":
|
||||||
break
|
break
|
||||||
|
@ -35,3 +33,7 @@ def create_base_socket_dir():
|
||||||
|
|
||||||
def delete_base_socket_dir():
|
def delete_base_socket_dir():
|
||||||
shutil.rmtree(BASE_SOCKET_DIR)
|
shutil.rmtree(BASE_SOCKET_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def stdout_socket_for_pid(pid: int) -> str:
|
||||||
|
return os.path.join(BASE_SOCKET_DIR, "{}.stdout.sock".format(pid))
|
||||||
|
|
|
@ -6,9 +6,16 @@ import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
import zmq
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
from catfish.utils.sockets import BASE_SOCKET_DIR, NEW_LINE, read_all_from_socket
|
from catfish.utils import aio
|
||||||
|
from catfish.utils.sockets import (
|
||||||
|
BASE_SOCKET_DIR,
|
||||||
|
NEW_LINE,
|
||||||
|
read_all_from_socket,
|
||||||
|
stdout_socket_for_pid,
|
||||||
|
)
|
||||||
|
|
||||||
WORKER_SERVER_SOCKET = os.path.join(BASE_SOCKET_DIR, "catfish.sock")
|
WORKER_SERVER_SOCKET = os.path.join(BASE_SOCKET_DIR, "catfish.sock")
|
||||||
|
|
||||||
|
@ -26,10 +33,36 @@ def send_to_server(payload):
|
||||||
return read_all_from_socket(sock)
|
return read_all_from_socket(sock)
|
||||||
|
|
||||||
|
|
||||||
|
def read_from_stdout_socket(socket_path):
|
||||||
|
ctx = zmq.Context()
|
||||||
|
sock = ctx.socket(zmq.SUB)
|
||||||
|
sock.connect("ipc://" + socket_path)
|
||||||
|
sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
while True:
|
||||||
|
yield sock.recv_string().strip()
|
||||||
|
|
||||||
|
|
||||||
def write_data(writer, data):
|
def write_data(writer, data):
|
||||||
writer.write(ujson.dumps(data).encode())
|
writer.write(ujson.dumps(data).encode())
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_stdout(process):
|
||||||
|
ctx = zmq.Context()
|
||||||
|
sock = ctx.socket(zmq.PUB)
|
||||||
|
|
||||||
|
socket_path = stdout_socket_for_pid(process.pid)
|
||||||
|
sock.bind("ipc://" + socket_path)
|
||||||
|
while True:
|
||||||
|
output = await process.stdout.readline()
|
||||||
|
if not output:
|
||||||
|
break
|
||||||
|
sock.send_string(output.decode())
|
||||||
|
process.kill()
|
||||||
|
sock.close()
|
||||||
|
ctx.destroy()
|
||||||
|
await aio.remove_file(socket_path)
|
||||||
|
|
||||||
|
|
||||||
async def client_connected(reader, writer):
|
async def client_connected(reader, writer):
|
||||||
data = ujson.loads(await reader.readline())
|
data = ujson.loads(await reader.readline())
|
||||||
if data["type"] == "process":
|
if data["type"] == "process":
|
||||||
|
@ -37,13 +70,17 @@ async def client_connected(reader, writer):
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*command,
|
*command,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
env={**os.environ, "PYTHONUNBUFFERED": "1"}
|
env={**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||||
)
|
)
|
||||||
|
asyncio.ensure_future(publish_stdout(proc))
|
||||||
write_data(writer, {"pid": proc.pid})
|
write_data(writer, {"pid": proc.pid})
|
||||||
elif data["type"] == "ping":
|
elif data["type"] == "ping":
|
||||||
write_data(writer, {"ping": "pong"})
|
write_data(writer, {"ping": "pong"})
|
||||||
else:
|
else:
|
||||||
write_data(writer, {"error": "Invalid command"})
|
write_data(writer, {"error": "Invalid command"})
|
||||||
|
await writer.drain()
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
async def start_server():
|
async def start_server():
|
||||||
|
|
10
setup.py
10
setup.py
|
@ -11,7 +11,15 @@ setup(
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
pathon_requires=">=3.6",
|
pathon_requires=">=3.6",
|
||||||
install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"],
|
install_requires=[
|
||||||
|
"click",
|
||||||
|
"daemonize",
|
||||||
|
"psutil",
|
||||||
|
"aiohttp",
|
||||||
|
"ujson",
|
||||||
|
"pyzmq",
|
||||||
|
"aiofiles",
|
||||||
|
],
|
||||||
entry_points="""
|
entry_points="""
|
||||||
[console_scripts]
|
[console_scripts]
|
||||||
ctf=catfish.__main__:cli
|
ctf=catfish.__main__:cli
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from itertools import count
|
from itertools import count
|
||||||
|
|
||||||
for num in count():
|
for num in count():
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
sys.stdout.write("Round {}\n".format(num))
|
print("Round {}".format(num)) # noqa: T001
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
|
import os
|
||||||
|
|
||||||
from catfish.utils.processes import is_process_running
|
from catfish.utils.processes import is_process_running
|
||||||
from catfish.worker.server import send_to_server
|
from catfish.utils.sockets import stdout_socket_for_pid
|
||||||
|
from catfish.worker.server import read_from_stdout_socket, send_to_server
|
||||||
from tests import BaseWorkerTestCase
|
from tests import BaseWorkerTestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,3 +18,17 @@ class WorkerServerTestCase(BaseWorkerTestCase):
|
||||||
def test_ping(self):
|
def test_ping(self):
|
||||||
response = send_to_server({"type": "ping"})
|
response = send_to_server({"type": "ping"})
|
||||||
self.assertEqual(response, {"ping": "pong"})
|
self.assertEqual(response, {"ping": "pong"})
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessLogsTestCase(BaseWorkerTestCase):
|
||||||
|
def test_creates_socket(self):
|
||||||
|
response = send_to_server({"type": "process", "command": self.DUMMY_EXE})
|
||||||
|
stdout_socket = stdout_socket_for_pid(response["pid"])
|
||||||
|
self.assertTrue(os.path.exists(stdout_socket))
|
||||||
|
|
||||||
|
def test_gets_logs(self):
|
||||||
|
response = send_to_server({"type": "process", "command": self.DUMMY_EXE})
|
||||||
|
stdout_socket = stdout_socket_for_pid(response["pid"])
|
||||||
|
stdout_iter = read_from_stdout_socket(stdout_socket)
|
||||||
|
for i in range(3):
|
||||||
|
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
||||||
|
|
Reference in a new issue