Add function to get root process

This commit is contained in:
Jake Howard 2018-12-12 22:29:42 +00:00
parent 3ed8ba7eb3
commit c2c5f955e8
Signed by: jake
GPG key ID: 57AFB45680EDD477
4 changed files with 52 additions and 1 deletions

View file

@ -2,6 +2,8 @@ from typing import List
import psutil import psutil
CURRENT_PROCESS = psutil.Process()
def terminate_processes(procs: List[psutil.Process], timeout=3): def terminate_processes(procs: List[psutil.Process], timeout=3):
# https://psutil.readthedocs.io/en/latest/#terminate-my-children # https://psutil.readthedocs.io/en/latest/#terminate-my-children
@ -14,3 +16,11 @@ def terminate_processes(procs: List[psutil.Process], timeout=3):
for p in alive: for p in alive:
p.kill() p.kill()
gone, alive = psutil.wait_procs(alive, timeout=timeout) gone, alive = psutil.wait_procs(alive, timeout=timeout)
return alive
def get_root_process() -> psutil.Process:
proc = CURRENT_PROCESS
while proc.parent():
proc = proc.parent()
return proc

View file

@ -5,9 +5,11 @@ from aiohttp import web
async def handle_request(request): async def handle_request(request):
return web.json_response({}) return web.json_response({})
def get_server(): def get_server():
return web.Server(handle_request) return web.Server(handle_request)
async def run_server(loop, port): async def run_server(loop, port):
click.echo("Starting server...") click.echo("Starting server...")
aiohttp_server = get_server() aiohttp_server = get_server()

View file

@ -3,6 +3,7 @@ from aiohttp.test_utils import unused_port
from click.testing import CliRunner from click.testing import CliRunner
from catfish.__main__ import cli from catfish.__main__ import cli
from catfish import worker from catfish import worker
from catfish.utils.processes import terminate_processes, CURRENT_PROCESS
import functools import functools
from multiprocessing import Process from multiprocessing import Process
@ -16,13 +17,16 @@ class BaseTestCase(AsyncTestCase):
def tearDown(self): def tearDown(self):
worker.stop_worker() worker.stop_worker()
terminate_processes(CURRENT_PROCESS.children(recursive=True))
class BaseWorkerTestCase(BaseTestCase): class BaseWorkerTestCase(BaseTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.unused_port = unused_port() self.unused_port = unused_port()
self.router_process = Process(target=worker.run, args=(self.unused_port,), daemon=True) self.router_process = Process(
target=worker.run, args=(self.unused_port,), daemon=True
)
self.router_process.start() self.router_process.start()
def tearDown(self): def tearDown(self):

View file

@ -0,0 +1,35 @@
import os
import subprocess
import psutil
from catfish.utils import processes
from tests import BaseTestCase
class RootProcessTestCase(BaseTestCase):
def test_root_processes(self):
root_process = processes.get_root_process()
self.assertIsNone(root_process.parent())
all_children = [proc.pid for proc in root_process.children(recursive=True)]
self.assertIn(os.getpid(), all_children)
def test_current_process(self):
self.assertEqual(processes.CURRENT_PROCESS.pid, os.getpid())
class TerminateProcessesTestCase(BaseTestCase):
def test_kills_lots_of_processes(self):
created_processes = []
for _ in range(10):
created_processes.append(subprocess.Popen("yes", stdout=subprocess.PIPE))
for proc in created_processes:
self.assertIsNone(proc.poll())
still_alive = processes.terminate_processes(
[psutil.Process(proc.pid) for proc in created_processes]
)
self.assertIsNone(still_alive)
for proc in created_processes:
self.assertEqual(proc.poll(), 0)