Add worker socket server
This commit is contained in:
parent
5b03540601
commit
1d32eeb24a
8 changed files with 107 additions and 6 deletions
|
@ -12,9 +12,9 @@ def get_server():
|
|||
return web.Server(handle_request)
|
||||
|
||||
|
||||
async def run_server(loop, port):
|
||||
click.echo("Starting server...")
|
||||
async def start_router(loop, port):
|
||||
click.echo("Starting router...")
|
||||
aiohttp_server = get_server()
|
||||
aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port)
|
||||
click.echo("Server listening on port {}".format(port))
|
||||
click.echo("Router listening on port {}".format(port))
|
||||
await aio_server.serve_forever()
|
||||
|
|
|
@ -24,3 +24,11 @@ def get_root_process() -> psutil.Process:
|
|||
while proc.parent():
|
||||
proc = proc.parent()
|
||||
return proc
|
||||
|
||||
|
||||
def is_process_running(pid: int) -> bool:
|
||||
try:
|
||||
psutil.Process(pid)
|
||||
return True
|
||||
except psutil.NoSuchProcess:
|
||||
return False
|
||||
|
|
24
catfish/utils/sockets.py
Normal file
24
catfish/utils/sockets.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
import select
|
||||
|
||||
import ujson
|
||||
|
||||
BUFFER_SIZE = 4096
|
||||
DEFAULT_SOCKET_READ_TIMEOUT = 0.01
|
||||
NEW_LINE = b"\n"
|
||||
|
||||
|
||||
def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool:
|
||||
readable, _, _ = select.select([socket], [], [], timeout)
|
||||
return socket in readable
|
||||
|
||||
|
||||
def read_all_from_socket(socket):
|
||||
data = b""
|
||||
while NEW_LINE not in data:
|
||||
if not socket_has_data(socket):
|
||||
break
|
||||
message = socket.recv(BUFFER_SIZE)
|
||||
if message == b"":
|
||||
break
|
||||
data += message
|
||||
return ujson.loads(data)
|
|
@ -5,9 +5,11 @@ import time
|
|||
|
||||
import psutil
|
||||
|
||||
from catfish.router import run_server
|
||||
from catfish.router import start_router
|
||||
from catfish.utils.processes import terminate_processes
|
||||
|
||||
from .server import start_server, WORKER_SERVER_SOCKET
|
||||
|
||||
PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid")
|
||||
|
||||
|
||||
|
@ -31,6 +33,13 @@ def wait_for_worker():
|
|||
time.sleep(0.1)
|
||||
|
||||
|
||||
def wait_for_running_worker():
|
||||
while not all([
|
||||
os.path.exists(WORKER_SERVER_SOCKET)
|
||||
]):
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def stop_worker():
|
||||
if is_running():
|
||||
terminate_processes([get_running_process()])
|
||||
|
@ -38,7 +47,7 @@ def stop_worker():
|
|||
|
||||
async def run_worker(port):
|
||||
loop = asyncio.get_running_loop()
|
||||
await asyncio.gather(run_server(loop, port))
|
||||
await asyncio.gather(start_router(loop, port), start_server())
|
||||
|
||||
|
||||
def run(port=8080):
|
||||
|
|
50
catfish/worker/server.py
Normal file
50
catfish/worker/server.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import asyncio
|
||||
import os
|
||||
import shlex
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import click
|
||||
|
||||
import ujson
|
||||
from catfish.utils.sockets import NEW_LINE, read_all_from_socket
|
||||
|
||||
WORKER_SERVER_SOCKET = os.path.join(tempfile.gettempdir(), "catfish.sock")
|
||||
|
||||
|
||||
def send_to_server(payload):
|
||||
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
|
||||
while True:
|
||||
try:
|
||||
sock.connect(WORKER_SERVER_SOCKET)
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
time.sleep(0.1)
|
||||
|
||||
sock.sendall(ujson.dumps(payload).encode() + NEW_LINE)
|
||||
return read_all_from_socket(sock)
|
||||
|
||||
|
||||
def write_data(writer, data):
|
||||
writer.write(ujson.dumps(data).encode())
|
||||
|
||||
|
||||
async def client_connected(reader, writer):
|
||||
data = ujson.loads(await reader.readline())
|
||||
if data["type"] == "process":
|
||||
command = shlex.split(data["command"])
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdout=subprocess.PIPE,
|
||||
env={**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||
)
|
||||
write_data(writer, {"pid": proc.pid})
|
||||
else:
|
||||
write_data(writer, {"error": "Invalid command"})
|
||||
|
||||
|
||||
async def start_server():
|
||||
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
|
||||
click.echo("Started server")
|
||||
await server.serve_forever()
|
2
setup.py
2
setup.py
|
@ -11,7 +11,7 @@ setup(
|
|||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
pathon_requires=">=3.6",
|
||||
install_requires=["click", "daemonize", "psutil", "aiohttp"],
|
||||
install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"],
|
||||
entry_points="""
|
||||
[console_scripts]
|
||||
ctf=catfish.__main__:cli
|
||||
|
|
|
@ -30,6 +30,7 @@ class BaseWorkerTestCase(BaseTestCase):
|
|||
target=worker.run, args=(self.unused_port,), daemon=True
|
||||
)
|
||||
self.router_process.start()
|
||||
worker.wait_for_running_worker()
|
||||
|
||||
def tearDown(self):
|
||||
self.router_process.terminate()
|
||||
|
|
9
tests/test_worker/test_server.py
Normal file
9
tests/test_worker/test_server.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from tests import BaseWorkerTestCase
|
||||
from catfish.worker.server import send_to_server
|
||||
from catfish.utils.processes import is_process_running
|
||||
|
||||
|
||||
class WorkerServerTestCase(BaseWorkerTestCase):
|
||||
def test_server_creates_process(self):
|
||||
response = send_to_server({"type": "process", "command": "yes"})
|
||||
self.assertTrue(is_process_running(response['pid']))
|
Reference in a new issue