Assign ports to processes

This commit is contained in:
Jake Howard 2018-12-23 12:55:56 +00:00
parent cce567d30e
commit 5c04b5f85c
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 63 additions and 13 deletions

View file

@ -6,9 +6,11 @@ import signal
import socket
import subprocess
from enum import Enum, auto
from typing import Optional
import click
import zmq
from aiohttp.test_utils import unused_port
import ujson
from catfish.project import Process, Project
@ -47,7 +49,9 @@ def write_data(writer, data):
writer.write(ujson.dumps(data).encode())
async def publish_stdout_for(process, ctf_process: Process, project: Project):
async def publish_stdout_for(
process, ctf_process: Process, project: Project, port: Optional[int]
):
ctx = zmq.Context()
sock = ctx.socket(zmq.PUB)
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
@ -63,32 +67,40 @@ async def publish_stdout_for(process, ctf_process: Process, project: Project):
exit_code = process.returncode
if exit_code in [-signal.SIGHUP, 0, 1]:
# If process gets SIGHUP, or exits cleanly / uncleanly, restart it
process = await start_process(project, ctf_process)
process = await start_process(project, ctf_process, port)
finally:
sock.close()
ctx.destroy()
await aio.remove_file(socket_path)
async def start_process(project: Project, process: Process):
command = shlex.split(process.command)
return await asyncio.create_subprocess_exec(
*command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env={
async def start_process(project: Project, process: Process, port: Optional[int]):
process_env = {
**os.environ,
**project.get_environment(),
"CATFISH_IDENT": process.ident,
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
},
}
command = process.command
if port is not None:
process_env["PORT"] = str(port)
command = command.replace("$PORT", str(port))
return await asyncio.create_subprocess_exec(
*shlex.split(command),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=process_env,
cwd=project.root
)
async def run_process_command(project: Project, process: Process):
proc = await start_process(project, process)
asyncio.ensure_future(publish_stdout_for(proc, process, project))
if "$PORT" in process.command:
port = unused_port()
else:
port = None
proc = await start_process(project, process, port)
asyncio.ensure_future(publish_stdout_for(proc, process, project, port))
return proc

View file

@ -137,3 +137,41 @@ class ProcessLogsTestCase(BaseWorkerTestCase):
stdout_iter = read_logs_for_process(self.process)
for i in range(3):
self.assertEqual(next(stdout_iter), "Round {}".format(i))
class ProcessPortTestCase(BaseWorkerTestCase):
def setUp(self):
super().setUp()
self.project = Project(self.EXAMPLE_DIR)
self.process = self.project.get_process("web")
def test_assigns_port(self):
response = send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
process = psutil.Process(response["pid"])
self.assertIn("PORT", process.environ())
self.assertEqual(self.process.port, process.environ()["PORT"])
def test_doesnt_assign_port(self):
response = send_to_server(
PayloadType.PROCESS, {"path": str(self.project.root), "process": "bg"}
)
process = psutil.Process(response["pid"])
self.assertNotIn("PORT", process.environ())
self.assertIsNone(self.process.port)
def test_keeps_port_on_restart(self):
response = send_to_server(
PayloadType.PROCESS,
{"path": str(self.project.root), "process": str(self.process.name)},
)
initial_pid = response["pid"]
initial_process = psutil.Process(initial_pid)
port = initial_process.environ()["PORT"]
psutil.Process(initial_pid).send_signal(signal.SIGHUP)
wait_for_process_terminate(initial_pid)
time.sleep(2)
new_process = get_running_process_for(self.process)
self.assertEqual(new_process.environ()["PORT"], port)