Assign ports to processes

This commit is contained in:
Jake Howard 2018-12-23 12:55:56 +00:00
parent cce567d30e
commit 5c04b5f85c
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 63 additions and 13 deletions

View file

@ -6,9 +6,11 @@ import signal
import socket import socket
import subprocess import subprocess
from enum import Enum, auto from enum import Enum, auto
from typing import Optional
import click import click
import zmq import zmq
from aiohttp.test_utils import unused_port
import ujson import ujson
from catfish.project import Process, Project from catfish.project import Process, Project
@ -47,7 +49,9 @@ def write_data(writer, data):
writer.write(ujson.dumps(data).encode()) writer.write(ujson.dumps(data).encode())
async def publish_stdout_for(process, ctf_process: Process, project: Project): async def publish_stdout_for(
process, ctf_process: Process, project: Project, port: Optional[int]
):
ctx = zmq.Context() ctx = zmq.Context()
sock = ctx.socket(zmq.PUB) sock = ctx.socket(zmq.PUB)
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket)) socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
@ -63,32 +67,40 @@ async def publish_stdout_for(process, ctf_process: Process, project: Project):
exit_code = process.returncode exit_code = process.returncode
if exit_code in [-signal.SIGHUP, 0, 1]: if exit_code in [-signal.SIGHUP, 0, 1]:
# If process gets SIGHUP, or exits cleanly / uncleanly, restart it # If process gets SIGHUP, or exits cleanly / uncleanly, restart it
process = await start_process(project, ctf_process) process = await start_process(project, ctf_process, port)
finally: finally:
sock.close() sock.close()
ctx.destroy() ctx.destroy()
await aio.remove_file(socket_path) await aio.remove_file(socket_path)
async def start_process(project: Project, process: Process): async def start_process(project: Project, process: Process, port: Optional[int]):
command = shlex.split(process.command) process_env = {
**os.environ,
**project.get_environment(),
"CATFISH_IDENT": process.ident,
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
}
command = process.command
if port is not None:
process_env["PORT"] = str(port)
command = command.replace("$PORT", str(port))
return await asyncio.create_subprocess_exec( return await asyncio.create_subprocess_exec(
*command, *shlex.split(command),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
env={ env=process_env,
**os.environ,
**project.get_environment(),
"CATFISH_IDENT": process.ident,
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
},
cwd=project.root cwd=project.root
) )
async def run_process_command(project: Project, process: Process): async def run_process_command(project: Project, process: Process):
proc = await start_process(project, process) if "$PORT" in process.command:
asyncio.ensure_future(publish_stdout_for(proc, process, project)) port = unused_port()
else:
port = None
proc = await start_process(project, process, port)
asyncio.ensure_future(publish_stdout_for(proc, process, project, port))
return proc return proc

View file

@ -137,3 +137,41 @@ class ProcessLogsTestCase(BaseWorkerTestCase):
stdout_iter = read_logs_for_process(self.process) stdout_iter = read_logs_for_process(self.process)
for i in range(3): for i in range(3):
self.assertEqual(next(stdout_iter), "Round {}".format(i)) self.assertEqual(next(stdout_iter), "Round {}".format(i))
class ProcessPortTestCase(BaseWorkerTestCase):
def setUp(self):
super().setUp()
self.project = Project(self.EXAMPLE_DIR)
self.process = self.project.get_process("web")
def test_assigns_port(self):
response = send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
process = psutil.Process(response["pid"])
self.assertIn("PORT", process.environ())
self.assertEqual(self.process.port, process.environ()["PORT"])
def test_doesnt_assign_port(self):
response = send_to_server(
PayloadType.PROCESS, {"path": str(self.project.root), "process": "bg"}
)
process = psutil.Process(response["pid"])
self.assertNotIn("PORT", process.environ())
self.assertIsNone(self.process.port)
def test_keeps_port_on_restart(self):
response = send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
initial_pid = response["pid"]
initial_process = psutil.Process(initial_pid)
port = initial_process.environ()["PORT"]
psutil.Process(initial_pid).send_signal(signal.SIGHUP)
wait_for_process_terminate(initial_pid)
time.sleep(2)
new_process = get_running_process_for(self.process)
self.assertEqual(new_process.environ()["PORT"], port)