Add worker socket server

This commit is contained in:
Jake Howard 2018-12-13 11:25:09 +00:00
parent 5b03540601
commit 1d32eeb24a
Signed by: jake
GPG key ID: 57AFB45680EDD477
8 changed files with 107 additions and 6 deletions

View file

@ -12,9 +12,9 @@ def get_server():
return web.Server(handle_request) return web.Server(handle_request)
async def run_server(loop, port): async def start_router(loop, port):
click.echo("Starting server...") click.echo("Starting router...")
aiohttp_server = get_server() aiohttp_server = get_server()
aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port) aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port)
click.echo("Server listening on port {}".format(port)) click.echo("Router listening on port {}".format(port))
await aio_server.serve_forever() await aio_server.serve_forever()

View file

@ -24,3 +24,11 @@ def get_root_process() -> psutil.Process:
while proc.parent(): while proc.parent():
proc = proc.parent() proc = proc.parent()
return proc return proc
def is_process_running(pid: int) -> bool:
try:
psutil.Process(pid)
return True
except psutil.NoSuchProcess:
return False

24
catfish/utils/sockets.py Normal file
View file

@ -0,0 +1,24 @@
import select
import ujson
BUFFER_SIZE = 4096
DEFAULT_SOCKET_READ_TIMEOUT = 0.01
NEW_LINE = b"\n"
def socket_has_data(socket, timeout=DEFAULT_SOCKET_READ_TIMEOUT) -> bool:
readable, _, _ = select.select([socket], [], [], timeout)
return socket in readable
def read_all_from_socket(socket):
data = b""
while NEW_LINE not in data:
if not socket_has_data(socket):
break
message = socket.recv(BUFFER_SIZE)
if message == b"":
break
data += message
return ujson.loads(data)

View file

@ -5,9 +5,11 @@ import time
import psutil import psutil
from catfish.router import run_server from catfish.router import start_router
from catfish.utils.processes import terminate_processes from catfish.utils.processes import terminate_processes
from .server import start_server, WORKER_SERVER_SOCKET
PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid") PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid")
@ -31,6 +33,13 @@ def wait_for_worker():
time.sleep(0.1) time.sleep(0.1)
def wait_for_running_worker():
while not all([
os.path.exists(WORKER_SERVER_SOCKET)
]):
time.sleep(0.1)
def stop_worker(): def stop_worker():
if is_running(): if is_running():
terminate_processes([get_running_process()]) terminate_processes([get_running_process()])
@ -38,7 +47,7 @@ def stop_worker():
async def run_worker(port): async def run_worker(port):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await asyncio.gather(run_server(loop, port)) await asyncio.gather(start_router(loop, port), start_server())
def run(port=8080): def run(port=8080):

50
catfish/worker/server.py Normal file
View file

@ -0,0 +1,50 @@
import asyncio
import os
import shlex
import socket
import subprocess
import tempfile
import time
import click
import ujson
from catfish.utils.sockets import NEW_LINE, read_all_from_socket
WORKER_SERVER_SOCKET = os.path.join(tempfile.gettempdir(), "catfish.sock")
def send_to_server(payload):
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
while True:
try:
sock.connect(WORKER_SERVER_SOCKET)
break
except ConnectionRefusedError:
time.sleep(0.1)
sock.sendall(ujson.dumps(payload).encode() + NEW_LINE)
return read_all_from_socket(sock)
def write_data(writer, data):
writer.write(ujson.dumps(data).encode())
async def client_connected(reader, writer):
data = ujson.loads(await reader.readline())
if data["type"] == "process":
command = shlex.split(data["command"])
proc = await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
env={**os.environ, "PYTHONUNBUFFERED": "1"}
)
write_data(writer, {"pid": proc.pid})
else:
write_data(writer, {"error": "Invalid command"})
async def start_server():
server = await asyncio.start_unix_server(client_connected, WORKER_SERVER_SOCKET)
click.echo("Started server")
await server.serve_forever()

View file

@ -11,7 +11,7 @@ setup(
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
pathon_requires=">=3.6", pathon_requires=">=3.6",
install_requires=["click", "daemonize", "psutil", "aiohttp"], install_requires=["click", "daemonize", "psutil", "aiohttp", "ujson"],
entry_points=""" entry_points="""
[console_scripts] [console_scripts]
ctf=catfish.__main__:cli ctf=catfish.__main__:cli

View file

@ -30,6 +30,7 @@ class BaseWorkerTestCase(BaseTestCase):
target=worker.run, args=(self.unused_port,), daemon=True target=worker.run, args=(self.unused_port,), daemon=True
) )
self.router_process.start() self.router_process.start()
worker.wait_for_running_worker()
def tearDown(self): def tearDown(self):
self.router_process.terminate() self.router_process.terminate()

View file

@ -0,0 +1,9 @@
from tests import BaseWorkerTestCase
from catfish.worker.server import send_to_server
from catfish.utils.processes import is_process_running
class WorkerServerTestCase(BaseWorkerTestCase):
def test_server_creates_process(self):
response = send_to_server({"type": "process", "command": "yes"})
self.assertTrue(is_process_running(response['pid']))