diff --git a/catfish/worker/server.py b/catfish/worker/server.py index 13271d5..4a4d91a 100644 --- a/catfish/worker/server.py +++ b/catfish/worker/server.py @@ -61,7 +61,8 @@ async def publish_stdout_for(process, ctf_process: Process, project: Project): sock.send_string(output.decode()) await process.wait() exit_code = process.returncode - if exit_code == -signal.SIGHUP: + if exit_code in [-signal.SIGHUP, 0, 1]: + # If process gets SIGHUP, or exits cleanly / uncleanly, restart it process = await start_process(project, ctf_process) finally: sock.close() diff --git a/example/etc/environments/development/Procfile b/example/etc/environments/development/Procfile index f2aab3f..96d95c8 100644 --- a/example/etc/environments/development/Procfile +++ b/example/etc/environments/development/Procfile @@ -1,2 +1,3 @@ web: python -m http.server $PORT bg: python src/dummy_program.py +temp: python src/die_soon.py diff --git a/example/src/dummy_program.py b/example/src/dummy_program.py index 9dec57f..20448e8 100755 --- a/example/src/dummy_program.py +++ b/example/src/dummy_program.py @@ -3,6 +3,7 @@ import time from itertools import count + for num in count(): time.sleep(0.5) print("Round {}".format(num)) # noqa: T001 diff --git a/tests/test_project/test_project.py b/tests/test_project/test_project.py index 1d1a11e..d47a385 100644 --- a/tests/test_project/test_project.py +++ b/tests/test_project/test_project.py @@ -15,7 +15,7 @@ class ProjectTestCase(BaseTestCase): Project("/nonexistent") def test_read_processes(self): - self.assertEqual(len(self.project.processes), 2) + self.assertEqual(len(self.project.processes), 3) web_process = self.project.processes[0] self.assertEqual(web_process.name, "web") self.assertEqual(web_process.command, "python -m http.server $PORT") @@ -26,6 +26,11 @@ class ProjectTestCase(BaseTestCase): self.assertEqual(bg_process.command, "python src/dummy_program.py") self.assertEqual(bg_process.project, self.project) + bg_process = self.project.processes[2] + self.assertEqual(bg_process.name, "temp") + self.assertEqual(bg_process.command, "python src/die_soon.py") + self.assertEqual(bg_process.project, self.project) + def test_get_process(self): self.assertEqual(self.project.get_process("web").name, "web") self.assertEqual(self.project.get_process("bg").name, "bg") diff --git a/tests/test_worker/test_server.py b/tests/test_worker/test_server.py index 1c51472..f6ce7ec 100644 --- a/tests/test_worker/test_server.py +++ b/tests/test_worker/test_server.py @@ -69,6 +69,18 @@ class ProcessWorkerTestCase(BaseWorkerTestCase): self.assertFalse(is_process_running(initial_pid)) self.assertTrue(is_process_running(new_process.pid)) + def test_process_restart_on_0_exit(self): + response = send_to_server( + PayloadType.PROCESS, {"path": str(self.project.root), "process": "temp"} + ) + initial_pid = response["pid"] + wait_for_process_terminate(initial_pid) + time.sleep(2) + new_process = find_running_process_for(self.project.get_process("temp")) + self.assertNotEqual(new_process.pid, initial_pid) + self.assertFalse(is_process_running(initial_pid)) + self.assertTrue(is_process_running(new_process.pid)) + def test_process_terminate(self): response = send_to_server( PayloadType.PROCESS,