Add function to get root process
This commit is contained in:
parent
3ed8ba7eb3
commit
c2c5f955e8
4 changed files with 52 additions and 1 deletions
|
@ -2,6 +2,8 @@ from typing import List
|
|||
|
||||
import psutil
|
||||
|
||||
CURRENT_PROCESS = psutil.Process()
|
||||
|
||||
|
||||
def terminate_processes(procs: List[psutil.Process], timeout=3):
|
||||
# https://psutil.readthedocs.io/en/latest/#terminate-my-children
|
||||
|
@ -14,3 +16,11 @@ def terminate_processes(procs: List[psutil.Process], timeout=3):
|
|||
for p in alive:
|
||||
p.kill()
|
||||
gone, alive = psutil.wait_procs(alive, timeout=timeout)
|
||||
return alive
|
||||
|
||||
|
||||
def get_root_process() -> psutil.Process:
|
||||
proc = CURRENT_PROCESS
|
||||
while proc.parent():
|
||||
proc = proc.parent()
|
||||
return proc
|
||||
|
|
|
@ -5,9 +5,11 @@ from aiohttp import web
|
|||
async def handle_request(request):
|
||||
return web.json_response({})
|
||||
|
||||
|
||||
def get_server():
|
||||
return web.Server(handle_request)
|
||||
|
||||
|
||||
async def run_server(loop, port):
|
||||
click.echo("Starting server...")
|
||||
aiohttp_server = get_server()
|
||||
|
|
|
@ -3,6 +3,7 @@ from aiohttp.test_utils import unused_port
|
|||
from click.testing import CliRunner
|
||||
from catfish.__main__ import cli
|
||||
from catfish import worker
|
||||
from catfish.utils.processes import terminate_processes, CURRENT_PROCESS
|
||||
import functools
|
||||
from multiprocessing import Process
|
||||
|
||||
|
@ -16,13 +17,16 @@ class BaseTestCase(AsyncTestCase):
|
|||
|
||||
def tearDown(self):
|
||||
worker.stop_worker()
|
||||
terminate_processes(CURRENT_PROCESS.children(recursive=True))
|
||||
|
||||
|
||||
class BaseWorkerTestCase(BaseTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.unused_port = unused_port()
|
||||
self.router_process = Process(target=worker.run, args=(self.unused_port,), daemon=True)
|
||||
self.router_process = Process(
|
||||
target=worker.run, args=(self.unused_port,), daemon=True
|
||||
)
|
||||
self.router_process.start()
|
||||
|
||||
def tearDown(self):
|
||||
|
|
35
tests/test_utils/test_processes.py
Normal file
35
tests/test_utils/test_processes.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
import subprocess
|
||||
|
||||
import psutil
|
||||
|
||||
from catfish.utils import processes
|
||||
from tests import BaseTestCase
|
||||
|
||||
|
||||
class RootProcessTestCase(BaseTestCase):
|
||||
def test_root_processes(self):
|
||||
root_process = processes.get_root_process()
|
||||
self.assertIsNone(root_process.parent())
|
||||
all_children = [proc.pid for proc in root_process.children(recursive=True)]
|
||||
self.assertIn(os.getpid(), all_children)
|
||||
|
||||
def test_current_process(self):
|
||||
self.assertEqual(processes.CURRENT_PROCESS.pid, os.getpid())
|
||||
|
||||
|
||||
class TerminateProcessesTestCase(BaseTestCase):
|
||||
def test_kills_lots_of_processes(self):
|
||||
created_processes = []
|
||||
for _ in range(10):
|
||||
created_processes.append(subprocess.Popen("yes", stdout=subprocess.PIPE))
|
||||
|
||||
for proc in created_processes:
|
||||
self.assertIsNone(proc.poll())
|
||||
|
||||
still_alive = processes.terminate_processes(
|
||||
[psutil.Process(proc.pid) for proc in created_processes]
|
||||
)
|
||||
self.assertIsNone(still_alive)
|
||||
for proc in created_processes:
|
||||
self.assertEqual(proc.poll(), 0)
|
Reference in a new issue