"""Tests for the catfish worker server: ping, process lifecycle, logs and ports."""

import signal
import time

import psutil

from catfish.project import Project
from catfish.utils.processes import (
    get_running_process_for,
    is_process_running,
    terminate_processes,
    wait_for_process_terminate,
)
from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase


def _launch(project, process_name):
    """Send a ``PROCESS`` payload for ``process_name`` and return the server's reply.

    The payload carries the project root and the process name, both as strings.
    Callers in this module read the ``"pid"`` key of the reply.
    """
    payload = {"path": str(project.root), "process": str(process_name)}
    return send_to_server(PayloadType.PROCESS, payload)


def _wait_for_exit(pid):
    """Block until ``pid`` has terminated, then pause for two more seconds."""
    wait_for_process_terminate(pid)
    # NOTE(review): the fixed pause presumably gives the worker time to react
    # to the exit (e.g. respawn the process) -- confirm against the worker.
    time.sleep(2)


class WorkerServerTestCase(BaseWorkerTestCase):
    """Basic request/response checks against the worker server."""

    def test_ping(self):
        """A ``PING`` payload is answered with ``{"ping": "pong"}``."""
        reply = send_to_server(PayloadType.PING, {})
        self.assertEqual(reply, {"ping": "pong"})


class ProcessWorkerTestCase(BaseWorkerTestCase):
    """Process start, environment, restart and terminate checks for ``bg``."""

    def setUp(self):
        """Load the example project and look up its ``bg`` process."""
        super().setUp()
        self.project = Project(self.EXAMPLE_DIR)
        self.process = self.project.get_process("bg")

    def _assert_replaced(self, replacement, old_pid):
        """Assert ``replacement`` is a live process distinct from the dead ``old_pid``."""
        self.assertNotEqual(replacement.pid, old_pid)
        self.assertFalse(is_process_running(old_pid))
        self.assertTrue(is_process_running(replacement.pid))

    def test_server_creates_process(self):
        """The reply's pid is running and is the one reported for the process."""
        reply = _launch(self.project, self.process.name)
        self.assertTrue(is_process_running(reply["pid"]))
        self.assertEqual(get_running_process_for(self.process).pid, reply["pid"])
        self.assertTrue(self.process.is_running)

    def test_additional_environment(self):
        """The started process sees the expected extra environment variables."""
        reply = _launch(self.project, self.process.name)
        environment = psutil.Process(reply["pid"]).environ()
        self.assertEqual(environment["PYTHONUNBUFFERED"], "1")
        self.assertEqual(environment["CATFISH_IDENT"], self.process.ident)
        self.assertEqual(environment["FOO"], "bar")

    def test_additional_path(self):
        """Every directory from ``get_extra_path()`` appears in the process ``PATH``."""
        reply = _launch(self.project, self.process.name)
        environment = psutil.Process(reply["pid"]).environ()
        # PATH is split on ":" here, i.e. the POSIX separator.
        search_dirs = environment["PATH"].split(":")
        for extra_dir in self.project.get_extra_path():
            self.assertIn(str(extra_dir), search_dirs)

    def test_sets_working_dir(self):
        """The started process' working directory is the project root."""
        reply = _launch(self.project, self.process.name)
        working_dir = psutil.Process(reply["pid"]).cwd()
        self.assertEqual(working_dir, str(self.project.root))

    def test_process_restart(self):
        """After ``SIGHUP`` a new process with a different pid is running."""
        old_pid = _launch(self.project, self.process.name)["pid"]
        psutil.Process(old_pid).send_signal(signal.SIGHUP)
        _wait_for_exit(old_pid)
        replacement = get_running_process_for(self.process)
        self._assert_replaced(replacement, old_pid)

    def test_process_restart_on_1_exit(self):
        """The ``die`` process is replaced by a new one once it exits.

        NOTE(review): per the test name ``die`` presumably exits with status 1;
        the example project is not visible from here -- confirm.
        """
        old_pid = _launch(self.project, "die")["pid"]
        _wait_for_exit(old_pid)
        replacement = get_running_process_for(self.project.get_process("die"))
        self._assert_replaced(replacement, old_pid)

    def test_process_restart_on_0_exit(self):
        """The ``exit`` process is replaced by a new one once it exits.

        NOTE(review): per the test name ``exit`` presumably exits with status 0;
        the example project is not visible from here -- confirm.
        """
        old_pid = _launch(self.project, "exit")["pid"]
        _wait_for_exit(old_pid)
        replacement = get_running_process_for(self.project.get_process("exit"))
        self._assert_replaced(replacement, old_pid)

    def test_process_terminate(self):
        """After ``terminate_processes`` no running process is reported."""
        old_pid = _launch(self.project, self.process.name)["pid"]
        terminate_processes([psutil.Process(old_pid)])
        _wait_for_exit(old_pid)
        self.assertIsNone(get_running_process_for(self.process))
        self.assertFalse(is_process_running(old_pid))


class ProcessLogsTestCase(BaseWorkerTestCase):
    """Checks for the log socket and log stream of the ``bg`` process."""

    def setUp(self):
        """Load the example project and look up its ``bg`` process."""
        super().setUp()
        self.project = Project(self.EXAMPLE_DIR)
        self.process = self.project.get_process("bg")

    def test_creates_socket(self):
        """Starting the process creates its logs socket under ``BASE_SOCKET_DIR``."""
        _launch(self.project, self.process.name)
        logs_socket = BASE_SOCKET_DIR.joinpath(self.process.logs_socket)
        self.assertTrue(logs_socket.exists())

    def test_gets_logs(self):
        """The first three log lines read back are ``Round 0`` to ``Round 2``."""
        _launch(self.project, self.process.name)
        log_lines = read_logs_for_process(self.process)
        for round_number in range(3):
            expected_line = "Round {}".format(round_number)
            self.assertEqual(next(log_lines), expected_line)


class ProcessPortTestCase(BaseWorkerTestCase):
    """Checks for the ``PORT`` environment variable, based on the ``web`` process."""

    def setUp(self):
        """Load the example project and look up its ``web`` process."""
        super().setUp()
        self.project = Project(self.EXAMPLE_DIR)
        self.process = self.project.get_process("web")

    def test_assigns_port(self):
        """The ``web`` process gets a ``PORT`` matching ``self.process.port``."""
        reply = _launch(self.project, self.process.name)
        running = psutil.Process(reply["pid"])
        self.assertIn("PORT", running.environ())
        self.assertEqual(self.process.port, int(running.environ()["PORT"]))

    def test_doesnt_assign_port(self):
        """The ``bg`` process is started without a ``PORT`` variable."""
        reply = _launch(self.project, "bg")
        running = psutil.Process(reply["pid"])
        self.assertNotIn("PORT", running.environ())
        # NOTE(review): ``self.process`` is the ``web`` process from setUp, not
        # the ``bg`` process launched above -- confirm this is the intended check.
        self.assertIsNone(self.process.port)

    def test_keeps_port_on_restart(self):
        """The process running after ``SIGHUP`` has the same ``PORT`` as before."""
        old_pid = _launch(self.project, self.process.name)["pid"]
        original = psutil.Process(old_pid)
        assigned_port = original.environ()["PORT"]
        psutil.Process(old_pid).send_signal(signal.SIGHUP)
        _wait_for_exit(old_pid)
        replacement = get_running_process_for(self.process)
        self.assertEqual(replacement.environ()["PORT"], assigned_port)