Assign ports to processes
This commit is contained in:
parent
cce567d30e
commit
5c04b5f85c
2 changed files with 63 additions and 13 deletions
|
@ -6,9 +6,11 @@ import signal
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import zmq
|
import zmq
|
||||||
|
from aiohttp.test_utils import unused_port
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
from catfish.project import Process, Project
|
from catfish.project import Process, Project
|
||||||
|
@ -47,7 +49,9 @@ def write_data(writer, data):
|
||||||
writer.write(ujson.dumps(data).encode())
|
writer.write(ujson.dumps(data).encode())
|
||||||
|
|
||||||
|
|
||||||
async def publish_stdout_for(process, ctf_process: Process, project: Project):
|
async def publish_stdout_for(
|
||||||
|
process, ctf_process: Process, project: Project, port: Optional[int]
|
||||||
|
):
|
||||||
ctx = zmq.Context()
|
ctx = zmq.Context()
|
||||||
sock = ctx.socket(zmq.PUB)
|
sock = ctx.socket(zmq.PUB)
|
||||||
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
|
socket_path = str(BASE_SOCKET_DIR.joinpath(ctf_process.logs_socket))
|
||||||
|
@ -63,32 +67,40 @@ async def publish_stdout_for(process, ctf_process: Process, project: Project):
|
||||||
exit_code = process.returncode
|
exit_code = process.returncode
|
||||||
if exit_code in [-signal.SIGHUP, 0, 1]:
|
if exit_code in [-signal.SIGHUP, 0, 1]:
|
||||||
# If process gets SIGHUP, or exits cleanly / uncleanly, restart it
|
# If process gets SIGHUP, or exits cleanly / uncleanly, restart it
|
||||||
process = await start_process(project, ctf_process)
|
process = await start_process(project, ctf_process, port)
|
||||||
finally:
|
finally:
|
||||||
sock.close()
|
sock.close()
|
||||||
ctx.destroy()
|
ctx.destroy()
|
||||||
await aio.remove_file(socket_path)
|
await aio.remove_file(socket_path)
|
||||||
|
|
||||||
|
|
||||||
async def start_process(project: Project, process: Process):
|
async def start_process(project: Project, process: Process, port: Optional[int]):
|
||||||
command = shlex.split(process.command)
|
process_env = {
|
||||||
return await asyncio.create_subprocess_exec(
|
|
||||||
*command,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
env={
|
|
||||||
**os.environ,
|
**os.environ,
|
||||||
**project.get_environment(),
|
**project.get_environment(),
|
||||||
"CATFISH_IDENT": process.ident,
|
"CATFISH_IDENT": process.ident,
|
||||||
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
|
"CATFISH_WORKER_PROCESS": str(CURRENT_PROCESS.pid),
|
||||||
},
|
}
|
||||||
|
command = process.command
|
||||||
|
if port is not None:
|
||||||
|
process_env["PORT"] = str(port)
|
||||||
|
command = command.replace("$PORT", str(port))
|
||||||
|
return await asyncio.create_subprocess_exec(
|
||||||
|
*shlex.split(command),
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
env=process_env,
|
||||||
cwd=project.root
|
cwd=project.root
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def run_process_command(project: Project, process: Process):
|
async def run_process_command(project: Project, process: Process):
|
||||||
proc = await start_process(project, process)
|
if "$PORT" in process.command:
|
||||||
asyncio.ensure_future(publish_stdout_for(proc, process, project))
|
port = unused_port()
|
||||||
|
else:
|
||||||
|
port = None
|
||||||
|
proc = await start_process(project, process, port)
|
||||||
|
asyncio.ensure_future(publish_stdout_for(proc, process, project, port))
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -137,3 +137,41 @@ class ProcessLogsTestCase(BaseWorkerTestCase):
|
||||||
stdout_iter = read_logs_for_process(self.process)
|
stdout_iter = read_logs_for_process(self.process)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
self.assertEqual(next(stdout_iter), "Round {}".format(i))
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessPortTestCase(BaseWorkerTestCase):
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.project = Project(self.EXAMPLE_DIR)
|
||||||
|
self.process = self.project.get_process("web")
|
||||||
|
|
||||||
|
def test_assigns_port(self):
|
||||||
|
response = send_to_server(
|
||||||
|
PayloadType.PROCESS,
|
||||||
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
|
)
|
||||||
|
process = psutil.Process(response["pid"])
|
||||||
|
self.assertIn("PORT", process.environ())
|
||||||
|
self.assertEqual(self.process.port, process.environ()["PORT"])
|
||||||
|
|
||||||
|
def test_doesnt_assign_port(self):
|
||||||
|
response = send_to_server(
|
||||||
|
PayloadType.PROCESS, {"path": str(self.project.root), "process": "bg"}
|
||||||
|
)
|
||||||
|
process = psutil.Process(response["pid"])
|
||||||
|
self.assertNotIn("PORT", process.environ())
|
||||||
|
self.assertIsNone(self.process.port)
|
||||||
|
|
||||||
|
def test_keeps_port_on_restart(self):
|
||||||
|
response = send_to_server(
|
||||||
|
PayloadType.PROCESS,
|
||||||
|
{"path": str(self.project.root), "process": str(self.process.name)},
|
||||||
|
)
|
||||||
|
initial_pid = response["pid"]
|
||||||
|
initial_process = psutil.Process(initial_pid)
|
||||||
|
port = initial_process.environ()["PORT"]
|
||||||
|
psutil.Process(initial_pid).send_signal(signal.SIGHUP)
|
||||||
|
wait_for_process_terminate(initial_pid)
|
||||||
|
time.sleep(2)
|
||||||
|
new_process = get_running_process_for(self.process)
|
||||||
|
self.assertEqual(new_process.environ()["PORT"], port)
|
||||||
|
|
Reference in a new issue