diff --git a/catfish/worker/__init__.py b/catfish/worker/__init__.py index 01f2bcf..837883c 100644 --- a/catfish/worker/__init__.py +++ b/catfish/worker/__init__.py @@ -5,7 +5,6 @@ import psutil from catfish.utils.processes import terminate_processes import asyncio from .router import run_server -from functools import partial PID_FILE = os.path.join(tempfile.gettempdir(), "catfish.pid") @@ -36,9 +35,10 @@ def stop_worker(): terminate_processes([get_running_process()]) -async def run_worker(): +async def run_worker(port): loop = asyncio.get_running_loop() - await asyncio.gather(run_server(loop)) + await asyncio.gather(run_server(loop, port)) -run = partial(asyncio.run, run_worker()) +def run(port=8080): + return asyncio.run(run_worker(port)) diff --git a/catfish/worker/router.py b/catfish/worker/router.py index d82162c..6d35000 100644 --- a/catfish/worker/router.py +++ b/catfish/worker/router.py @@ -5,10 +5,12 @@ from aiohttp import web async def handle_request(request): return web.json_response({}) +def get_server(): + return web.Server(handle_request) -async def run_server(loop, port=8080): +async def run_server(loop, port): click.echo("Starting server...") - aiohttp_server = web.Server(handle_request) + aiohttp_server = get_server() aio_server = await loop.create_server(aiohttp_server, "0.0.0.0", port) click.echo("Server listening on port {}".format(port)) await aio_server.serve_forever() diff --git a/dev-requirements.txt b/dev-requirements.txt index a571314..88fb04a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,4 @@ +aiounittest==1.1.0 black==18.9b0 flake8==3.6.0 flake8-comprehensions==1.4.1 diff --git a/tests/__init__.py b/tests/__init__.py index 4cea90f..34f552c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,11 +1,13 @@ -from unittest import TestCase +from aiounittest import AsyncTestCase +from aiohttp.test_utils import unused_port from click.testing import CliRunner from catfish.__main__ import cli from catfish import worker import functools +from multiprocessing import Process -class BaseTestCase(TestCase): +class BaseTestCase(AsyncTestCase): def setUp(self): worker.stop_worker() self.cli_runner = CliRunner() @@ -14,3 +16,15 @@ class BaseTestCase(TestCase): def tearDown(self): worker.stop_worker() + + +class BaseWorkerTestCase(BaseTestCase): + def setUp(self): + super().setUp() + self.unused_port = unused_port() + self.router_process = Process(target=worker.run, args=(self.unused_port,), daemon=True) + self.router_process.start() + + def tearDown(self): + self.router_process.terminate() + super().tearDown() diff --git a/tests/test_worker/__init__.py b/tests/test_worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_worker/test_router.py b/tests/test_worker/test_router.py new file mode 100644 index 0000000..2e5aab6 --- /dev/null +++ b/tests/test_worker/test_router.py @@ -0,0 +1,6 @@ +from tests import BaseWorkerTestCase + + +class RouterTestCase(BaseWorkerTestCase): + def test_thing(self): + self.assertTrue(True)