Add worker socket server
This commit is contained in:
parent
5b03540601
commit
1d32eeb24a
8 changed files with 107 additions and 6 deletions
|
@ -12,9 +12,9 @@ def get_server():
|
||||||
return web.Server(handle_request)
|
return web.Server(handle_request)
|
||||||
|
|
||||||
|
|
||||||
async def run_server(loop, port):
|
async def start_router(loop, port):
|
||||||
click.echo("Starting server...")
|
click.echo("Starting router...")
|
||||||
aiohttp_server = get_server()
|
aiohttp_server = get_server()
|
||||||
aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port)
|
aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port)
|
||||||
click.echo("Server listening on port {}".format(port))
|
click.echo("Router listening on port {}".format(port))
|
||||||
await aio_server.serve_forever()
|
await aio_server.serve_forever()
|
||||||
|
|
|
@ -24,3 +24,11 @@ def get_root_process() -> psutil.Process:
|
||||||
while proc.parent():
|
while proc.parent():
|
||||||
proc = proc.parent()
|
proc = proc.parent()
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
def is_process_running(pid: int) -> bool:
|
||||||
|
try:
|
||||||
|
psutil.Process(pid)
|
||||||
|
return True
|
||||||
|
except psutil.NoSuchProcess:
|
||||||
|
return False
|
||||||
|
|
24
catfish/utils/sockets.py
Normal file
24
catfish/utils/sockets.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import select
|
||||||
|
|
||||||
|
import ujson
|
||||||
|
|
||||||
|
BUFFER_SIZE = 4096
|
||||||
|
DEFAULT_SOCKET_READ_TIMEOUT = 0.01
|
||||||
|
NEW_LINE = b"\n"
|
||||||
|
|
||||||
|
|
||||||
|
def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool:
|
||||||
|
readable, _, _ = select.select([socket], [], [], timeout)
|
||||||
|
return socket in readable
|
||||||
|
|
||||||
|
|
||||||
|
def read_all_from_socket(socket):
|
||||||
|
data = b""
|
||||||
|
while NEW_LINE not in data:
|
||||||
|
if not socket_has_data(socket):
|
||||||
|
break
|
||||||
|
message = socket.recv(BUFFER_SIZE)
|
||||||
|
if message == b"":
|
||||||
|
break
|
||||||
|
data += message
|
||||||
|
return ujson.loads(data)
|
|
@ -5,9 +5,11 @@ import time
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
from catfish.router import run_server
|
from catfish.router import start_router
|
||||||
from catfish.utils.processes import terminate_processes
|
from catfish.utils.processes import terminate_processes
|
||||||
|
|
||||||
|
from .server import start_server, WORKER_SERVER_SOCKET
|
||||||
|
|
||||||
PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid")
|
PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid")
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,6 +33,13 @@ def wait_for_worker():
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_running_worker():
|
||||||
|
while not all([
|
||||||
|
os.path.exists(WORKER_SERVER_SOCKET)
|
||||||
|
]):
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
def stop_worker():
|
def stop_worker():
|
||||||
if is_running():
|
if is_running():
|
||||||
terminate_processes([get_running_process()])
|
terminate_processes([get_running_process()])
|
||||||
|
@ -38,7 +47,7 @@ def stop_worker():
|
||||||
|
|
||||||
async def run_worker(port):
|
async def run_worker(port):
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await asyncio.gather(run_server(loop, port))
|
await asyncio.gather(start_router(loop, port), start_server())
|
||||||
|
|
||||||
|
|
||||||
def run(port=8080):
|
def run(port=8080):
|
||||||
|
|
50
catfish/worker/server.py
Normal file
50
catfish/worker/server.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import click
|
||||||
|
|
||||||
|
import ujson
|
||||||
|
from catfish.utils.sockets import NEW_LINE, read_all_from_socket
|
||||||
|
|
||||||
|
WORKER_SERVER_SOCKET = os.path.join(tempfile.gettempdir(), "catfish.sock")
|
||||||
|
|
||||||
|
|
||||||
|
def send_to_server(payload):
|
||||||
|
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
sock.connect(WORKER_SERVER_SOCKET)
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
sock.sendall(ujson.dumps(payload).encode() + NEW_LINE)
|
||||||
|
return read_all_from_socket(sock)
|
||||||
|
|
||||||
|
|
||||||
|
def write_data(writer, data):
|
||||||
|
writer.write(ujson.dumps(data).encode())
|
||||||
|
|
||||||
|
|
||||||
|
async def client_connected(reader, writer):
|
||||||
|
data = ujson.loads(await reader.readline())
|
||||||
|
if data["type"] == "process":
|
||||||
|
command = shlex.split(data["command"])
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*command,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
env={**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||||
|
)
|
||||||
|
write_data(writer, {"pid": proc.pid})
|
||||||
|
else:
|
||||||
|
write_data(writer, {"error": "Invalid command"})
|
||||||
|
|
||||||
|
|
||||||
|
async def start_server():
|
||||||
|
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
|
||||||
|
click.echo("Started server")
|
||||||
|
await server.serve_forever()
|
2
setup.py
2
setup.py
|
@ -11,7 +11,7 @@ setup(
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
pathon_requires=">=3.6",
|
pathon_requires=">=3.6",
|
||||||
install_requires=["click", "daemonize", "psutil", "aiohttp"],
|
install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"],
|
||||||
entry_points="""
|
entry_points="""
|
||||||
[console_scripts]
|
[console_scripts]
|
||||||
ctf=catfish.__main__:cli
|
ctf=catfish.__main__:cli
|
||||||
|
|
|
@ -30,6 +30,7 @@ class BaseWorkerTestCase(BaseTestCase):
|
||||||
target=worker.run, args=(self.unused_port,), daemon=True
|
target=worker.run, args=(self.unused_port,), daemon=True
|
||||||
)
|
)
|
||||||
self.router_process.start()
|
self.router_process.start()
|
||||||
|
worker.wait_for_running_worker()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.router_process.terminate()
|
self.router_process.terminate()
|
||||||
|
|
9
tests/test_worker/test_server.py
Normal file
9
tests/test_worker/test_server.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
from tests import BaseWorkerTestCase
|
||||||
|
from catfish.worker.server import send_to_server
|
||||||
|
from catfish.utils.processes import is_process_running
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerServerTestCase(BaseWorkerTestCase):
|
||||||
|
def test_server_creates_process(self):
|
||||||
|
response = send_to_server({"type": "process", "command": "yes"})
|
||||||
|
self.assertTrue(is_process_running(response['pid']))
|
Reference in a new issue