Assign ports to processes
This commit is contained in:
parent
cce567d30e
commit
5c04b5f85c
2 changed files with 63 additions and 13 deletions
|
@ -6,9 +6,11 @@ import signal
|
|||
import socket
|
||||
import subprocess
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import zmq
|
||||
from aiohttp.test_utils import unused_port
|
||||
|
||||
import ujson
|
||||
from catfish.project import Process, Project
|
||||
|
@ -47,7 +49,9 @@ def write_data(writer, data):
|
|||
writer.write(ujson.dumps(data).encode())
|
||||
|
||||
|
||||
async def publish_stdout_for(process, ctf_process: Process, project: Project):
|
||||
async def publish_stdout_for(
|
||||
process, ctf_process: Process, project: Project, port: Optional[int]
|
||||
):
|
||||
ctx = zmq.Context()
|
||||
sock = ctx.socket(zmq.PUB)
|
||||
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
|
||||
|
@ -63,32 +67,40 @@ async def publish_stdout_for(process, ctf_process: Process, project: Project):
|
|||
exit_code = process.returncode
|
||||
if exit_code in [-signal.SIGHUP, 0, 1]:
|
||||
# If process gets SIGHUP, or exits cleanly / uncleanly, restart it
|
||||
process = await start_process(project, ctf_process)
|
||||
process = await start_process(project, ctf_process, port)
|
||||
finally:
|
||||
sock.close()
|
||||
ctx.destroy()
|
||||
await aio.remove_file(socket_path)
|
||||
|
||||
|
||||
async def start_process(project: Project, process: Process):
|
||||
command = shlex.split(process.command)
|
||||
async def start_process(project: Project, process: Process, port: Optional[int]):
|
||||
process_env = {
|
||||
**os.environ,
|
||||
**project.get_environment(),
|
||||
"CATFISH_IDENT": process.ident,
|
||||
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
|
||||
}
|
||||
command = process.command
|
||||
if port is not None:
|
||||
process_env["PORT"] = str(port)
|
||||
command = command.replace("$PORT", str(port))
|
||||
return await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
*shlex.split(command),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
env={
|
||||
**os.environ,
|
||||
**project.get_environment(),
|
||||
"CATFISH_IDENT": process.ident,
|
||||
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
|
||||
},
|
||||
env=process_env,
|
||||
cwd=project.root
|
||||
)
|
||||
|
||||
|
||||
async def run_process_command(project: Project, process: Process):
|
||||
proc = await start_process(project, process)
|
||||
asyncio.ensure_future(publish_stdout_for(proc, process, project))
|
||||
if "$PORT" in process.command:
|
||||
port = unused_port()
|
||||
else:
|
||||
port = None
|
||||
proc = await start_process(project, process, port)
|
||||
asyncio.ensure_future(publish_stdout_for(proc, process, project, port))
|
||||
return proc
|
||||
|
||||
|
||||
|
|
|
@ -137,3 +137,41 @@ class ProcessLogsTestCase(BaseWorkerTestCase):
|
|||
stdout_iter = read_logs_for_process(self.process)
|
||||
for i in range(3):
|
||||
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
||||
|
||||
|
||||
class ProcessPortTestCase(BaseWorkerTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.project = Project(self.EXAMPLE_DIR)
|
||||
self.process = self.project.get_process("web")
|
||||
|
||||
def test_assigns_port(self):
|
||||
response = send_to_server(
|
||||
PayloadType.PROCESS,
|
||||
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||
)
|
||||
process = psutil.Process(response["pid"])
|
||||
self.assertIn("PORT", process.environ())
|
||||
self.assertEqual(self.process.port, process.environ()["PORT"])
|
||||
|
||||
def test_doesnt_assign_port(self):
|
||||
response = send_to_server(
|
||||
PayloadType.PROCESS, {"path": str(self.project.root), "process": "bg"}
|
||||
)
|
||||
process = psutil.Process(response["pid"])
|
||||
self.assertNotIn("PORT", process.environ())
|
||||
self.assertIsNone(self.process.port)
|
||||
|
||||
def test_keeps_port_on_restart(self):
|
||||
response = send_to_server(
|
||||
PayloadType.PROCESS,
|
||||
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||
)
|
||||
initial_pid = response["pid"]
|
||||
initial_process = psutil.Process(initial_pid)
|
||||
port = initial_process.environ()["PORT"]
|
||||
psutil.Process(initial_pid).send_signal(signal.SIGHUP)
|
||||
wait_for_process_terminate(initial_pid)
|
||||
time.sleep(2)
|
||||
new_process = get_running_process_for(self.process)
|
||||
self.assertEqual(new_process.environ()["PORT"], port)
|
||||
|
|
Reference in a new issue