archive
/
catfish
Archived
1
Fork 0
This repository has been archived on 2023-03-26. You can view files and clone it, but cannot push or open issues or pull requests.
catfish/tests/test_worker/test_server.py

131 lines
4.8 KiB
Python

import signal
import time
import psutil
from catfish.project import Project
from catfish.utils.processes import (
find_running_process_for,
is_process_running,
terminate_processes,
wait_for_process_terminate,
)
from catfish.worker import BASE_SOCKET_DIR
from catfish.worker.server import PayloadType, read_logs_for_process, send_to_server
from tests import BaseWorkerTestCase
class WorkerServerTestCase(BaseWorkerTestCase):
    """Basic request/response checks against the worker server."""

    def test_ping(self):
        """A PING payload with an empty body is answered with a pong."""
        expected_reply = {"ping": "pong"}
        reply = send_to_server(PayloadType.PING, {})
        self.assertEqual(reply, expected_reply)
class ProcessWorkerTestCase(BaseWorkerTestCase):
    """Start, inspect, restart and terminate processes through the worker server.

    Every test launches a process from the example project by sending a
    PROCESS payload to the server and then inspects the resulting OS process.
    """

    # Seconds to pause after the original process has terminated before
    # inspecting what is running.
    SETTLE_SECONDS = 2

    def setUp(self):
        """Load the example project and select its "bg" process as the default."""
        super().setUp()
        self.project = Project(self.EXAMPLE_DIR)
        self.process = self.project.get_process("bg")

    def _start_process(self, process_name=None):
        """Send a PROCESS payload to the server and return its response.

        ``process_name`` defaults to the name of ``self.process``.
        """
        if process_name is None:
            process_name = str(self.process.name)
        return send_to_server(
            PayloadType.PROCESS,
            {"path": str(self.project.root), "process": process_name},
        )

    def _require_running(self, process):
        """Return the running OS process for ``process``, failing if there is none.

        ``find_running_process_for`` can return ``None``; asserting here turns
        that case into a clear test failure instead of an ``AttributeError``
        when ``.pid`` is read from the result.
        """
        running = find_running_process_for(process)
        self.assertIsNotNone(
            running, "no running process found for {!r}".format(process)
        )
        return running

    def _assert_restarted(self, initial_pid, process_name=None):
        """Wait for ``initial_pid`` to exit, then check a replacement is running.

        ``process_name`` selects which project process to look up; when it is
        omitted, ``self.process`` is used.
        """
        wait_for_process_terminate(initial_pid)
        time.sleep(self.SETTLE_SECONDS)
        if process_name is None:
            process = self.process
        else:
            process = self.project.get_process(process_name)
        new_process = self._require_running(process)
        self.assertNotEqual(new_process.pid, initial_pid)
        self.assertFalse(is_process_running(initial_pid))
        self.assertTrue(is_process_running(new_process.pid))

    def test_server_creates_process(self):
        """The pid reported by the server is alive and matches the process lookup."""
        response = self._start_process()
        self.assertTrue(is_process_running(response["pid"]))
        self.assertEqual(self._require_running(self.process).pid, response["pid"])

    def test_additional_environment(self):
        """The launched process sees the expected extra environment variables."""
        response = self._start_process()
        env = psutil.Process(response["pid"]).environ()
        self.assertEqual(env["PYTHONUNBUFFERED"], "1")
        self.assertEqual(env["CATFISH_IDENT"], self.process.ident)
        # NOTE(review): FOO=bar presumably comes from the example project's
        # configuration -- confirm against the fixture.
        self.assertEqual(env["FOO"], "bar")

    def test_additional_path(self):
        """Every extra path of the project appears in the process's PATH."""
        response = self._start_process()
        env = psutil.Process(response["pid"]).environ()
        path_dirs = env["PATH"].split(":")
        for path in self.project.get_extra_path():
            self.assertIn(str(path), path_dirs)

    def test_process_restart(self):
        """Sending SIGHUP to the process results in a new process with a new pid."""
        initial_pid = self._start_process()["pid"]
        psutil.Process(initial_pid).send_signal(signal.SIGHUP)
        self._assert_restarted(initial_pid)

    def test_process_restart_on_1_exit(self):
        """The "die" process is replaced by a new one after it exits.

        Presumably "die" exits with status 1 (per the test name) -- confirm
        against the example project.
        """
        initial_pid = self._start_process("die")["pid"]
        self._assert_restarted(initial_pid, "die")

    def test_process_restart_on_0_exit(self):
        """The "exit" process is replaced by a new one after it exits.

        Presumably "exit" exits with status 0 (per the test name) -- confirm
        against the example project.
        """
        initial_pid = self._start_process("exit")["pid"]
        self._assert_restarted(initial_pid, "exit")

    def test_process_terminate(self):
        """A process stopped via ``terminate_processes`` is not started again."""
        initial_pid = self._start_process()["pid"]
        terminate_processes([psutil.Process(initial_pid)])
        wait_for_process_terminate(initial_pid)
        # Pause before checking, so a restart (if one were to happen) would
        # have had time to show up.
        time.sleep(self.SETTLE_SECONDS)
        self.assertIsNone(find_running_process_for(self.process))
        self.assertFalse(is_process_running(initial_pid))
class ProcessLogsTestCase(BaseWorkerTestCase):
    """Checks on the log socket and log stream of a server-started process."""

    def setUp(self):
        """Load the example project and pick its "bg" process."""
        super().setUp()
        example_project = Project(self.EXAMPLE_DIR)
        self.project = example_project
        self.process = example_project.get_process("bg")

    def test_creates_socket(self):
        """Starting the process makes its logs socket appear under BASE_SOCKET_DIR."""
        payload = {
            "path": str(self.project.root),
            "process": str(self.process.name),
        }
        send_to_server(PayloadType.PROCESS, payload)
        socket_path = BASE_SOCKET_DIR.joinpath(self.process.logs_socket)
        socket_present = socket_path.exists()
        self.assertTrue(socket_present)

    def test_gets_logs(self):
        """The first three log lines read back are "Round 0" through "Round 2"."""
        payload = {
            "path": str(self.project.root),
            "process": str(self.process.name),
        }
        send_to_server(PayloadType.PROCESS, payload)
        log_lines = read_logs_for_process(self.process)
        round_number = 0
        while round_number < 3:
            line = next(log_lines)
            self.assertEqual(line, "Round {}".format(round_number))
            round_number += 1