"""Tests for the catfish worker server: ping, process lifecycle and logs."""

import signal
import time

import psutil

from catfish.project import Project
from catfish.utils.processes import (
    find_running_process_for,
    is_process_running,
    terminate_processes,
    wait_for_process_terminate,
)
from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase

# Seconds slept after a process has exited, before inspecting what is running.
# NOTE(review): presumably this gives the worker time to respawn (or clean up)
# the process -- confirm, and consider polling instead of a fixed delay.
_SETTLE_SECONDS = 2


def _start_process(project, process_name):
    """Ask the worker server to start ``process_name`` from ``project``.

    Sends a ``PayloadType.PROCESS`` payload carrying the project root and the
    process name (both stringified) and returns the server's response.
    """
    return send_to_server(
        PayloadType.PROCESS,
        {"path": str(project.root), "process": str(process_name)},
    )


def _assert_restarted(test_case, process, initial_pid):
    """Assert that ``process`` is running again under a new pid.

    Waits for ``initial_pid`` to terminate, sleeps for ``_SETTLE_SECONDS``,
    then checks that a running process is found for ``process``, that its pid
    differs from ``initial_pid``, and that only the new pid is running.
    """
    wait_for_process_terminate(initial_pid)
    time.sleep(_SETTLE_SECONDS)
    new_process = find_running_process_for(process)
    # Fail with a clear assertion instead of an AttributeError on ``.pid``
    # when nothing was restarted.
    test_case.assertIsNotNone(new_process)
    test_case.assertNotEqual(new_process.pid, initial_pid)
    test_case.assertFalse(is_process_running(initial_pid))
    test_case.assertTrue(is_process_running(new_process.pid))


class WorkerServerTestCase(BaseWorkerTestCase):
    """Basic request/response checks against the worker server."""

    def test_ping(self):
        """A PING payload is answered with ``{"ping": "pong"}``."""
        response = send_to_server(PayloadType.PING, {})
        self.assertEqual(response, {"ping": "pong"})


class ProcessWorkerTestCase(BaseWorkerTestCase):
    """Lifecycle of processes started through the worker server."""

    def setUp(self):
        """Load the example project and select its ``bg`` process."""
        super().setUp()
        self.project = Project(self.EXAMPLE_DIR)
        self.process = self.project.get_process("bg")

    def test_server_creates_process(self):
        """The pid in the response is running and is found for the process."""
        response = _start_process(self.project, self.process.name)
        self.assertTrue(is_process_running(response["pid"]))
        running = find_running_process_for(self.process)
        # Guard the ``.pid`` access so a missing process is a test failure,
        # not an AttributeError.
        self.assertIsNotNone(running)
        self.assertEqual(running.pid, response["pid"])

    def test_additional_environment(self):
        """The started process sees the expected environment variables."""
        response = _start_process(self.project, self.process.name)
        env = psutil.Process(response["pid"]).environ()
        self.assertEqual(env["PYTHONUNBUFFERED"], "1")
        self.assertEqual(env["CATFISH_IDENT"], self.process.ident)
        # NOTE(review): FOO=bar presumably comes from the example project's
        # configuration -- confirm against the fixture.
        self.assertEqual(env["FOO"], "bar")

    def test_additional_path(self):
        """Every project extra path appears in the process's ``PATH``."""
        response = _start_process(self.project, self.process.name)
        env = psutil.Process(response["pid"]).environ()
        path_dirs = env["PATH"].split(":")
        for path in self.project.get_extra_path():
            self.assertIn(str(path), path_dirs)

    def test_process_restart(self):
        """After SIGHUP the process is running again under a new pid."""
        response = _start_process(self.project, self.process.name)
        initial_pid = response["pid"]
        psutil.Process(initial_pid).send_signal(signal.SIGHUP)
        _assert_restarted(self, self.process, initial_pid)

    def test_process_restart_on_1_exit(self):
        """The ``die`` process is running again under a new pid after exiting."""
        # NOTE(review): the test name implies ``die`` exits with status 1;
        # confirm against the example project's definition.
        response = _start_process(self.project, "die")
        _assert_restarted(self, self.project.get_process("die"), response["pid"])

    def test_process_restart_on_0_exit(self):
        """The ``exit`` process is running again under a new pid after exiting."""
        # NOTE(review): the test name implies ``exit`` exits with status 0;
        # confirm against the example project's definition.
        response = _start_process(self.project, "exit")
        _assert_restarted(self, self.project.get_process("exit"), response["pid"])

    def test_process_terminate(self):
        """A process stopped via ``terminate_processes`` is not running again."""
        response = _start_process(self.project, self.process.name)
        initial_pid = response["pid"]
        terminate_processes([psutil.Process(initial_pid)])
        wait_for_process_terminate(initial_pid)
        time.sleep(_SETTLE_SECONDS)
        self.assertIsNone(find_running_process_for(self.process))
        self.assertFalse(is_process_running(initial_pid))


class ProcessLogsTestCase(BaseWorkerTestCase):
    """Log socket creation and log streaming for started processes."""

    def setUp(self):
        """Load the example project and select its ``bg`` process."""
        super().setUp()
        self.project = Project(self.EXAMPLE_DIR)
        self.process = self.project.get_process("bg")

    def test_creates_socket(self):
        """Starting a process creates its logs socket under ``BASE_SOCKET_DIR``."""
        _start_process(self.project, self.process.name)
        stdout_socket = BASE_SOCKET_DIR.joinpath(self.process.logs_socket)
        self.assertTrue(stdout_socket.exists())

    def test_gets_logs(self):
        """The first three log entries read are ``Round 0`` .. ``Round 2``."""
        _start_process(self.project, self.process.name)
        stdout_iter = read_logs_for_process(self.process)
        for i in range(3):
            self.assertEqual(next(stdout_iter), "Round {}".format(i))