diff --git a/ipc_unix/client.py b/ipc_unix/client.py new file mode 100644 index 0000000..8325380 --- /dev/null +++ b/ipc_unix/client.py @@ -0,0 +1,10 @@ +import socket +from ipc_unix.utils import read_payload +import ujson + + +def send_to(socket_path, data): + with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: + sock.connect(socket_path) + sock.sendall(ujson.dumps(data).encode() + b"\n") + return read_payload(sock) diff --git a/ipc_unix/server.py b/ipc_unix/server.py new file mode 100644 index 0000000..55807e9 --- /dev/null +++ b/ipc_unix/server.py @@ -0,0 +1,40 @@ +import socketserver +import threading +import ujson +from ipc_unix.utils import read_payload + + +class RequestHandler(socketserver.BaseRequestHandler): + def handle_request(self, request): + raise NotImplementedError("Failed to override `handle_request`") + + def handle(self): + data = read_payload(self.request) + response = self.handle_request(data) + self.request.sendall(ujson.dumps(response).encode()) + + +class Server: + def __init__(self, socket_path): + class InstanceRequestHandler(RequestHandler): + handle_request = self.handle_request + + self.server = socketserver.UnixStreamServer(socket_path, InstanceRequestHandler) + + def serve_forever(self): + self.server.serve_forever() + + def serve_in_thread(self): + thread = threading.Thread(target=self.serve_forever) + thread.start() + return thread + + def shutdown(self): + self.server.shutdown() + + def close(self): + self.shutdown() + self.server.server_close() + + def handle_request(self, request): + raise NotImplementedError("Must override `handle_request`") diff --git a/ipc_unix/utils.py b/ipc_unix/utils.py new file mode 100644 index 0000000..6ba2124 --- /dev/null +++ b/ipc_unix/utils.py @@ -0,0 +1,11 @@ +import ujson + + +def read_payload(payload): + data = b"" + while b"\n" not in data: + message = payload.recv(4096) + if message == b"": + break + data += message + return ujson.loads(data) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9f60e21 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,14 @@ +from ipc_unix import server +import tempfile +import os + + +class EchoServer(server.Server): + def handle_request(self, request): + return request + + +def get_random_path(): + _, temp_file_path = tempfile.mkstemp() + os.remove(temp_file_path) + return temp_file_path diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..c1882a6 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,40 @@ +from unittest import TestCase +from tests import EchoServer, get_random_path +from ipc_unix import client +from functools import partial + + +class BasicServerTestCase(TestCase): + def setUp(self): + self.socket_path = get_random_path() + self.server = EchoServer(self.socket_path) + self.server.serve_in_thread() + self.send_to_client = partial(client.send_to, self.socket_path) + + def tearDown(self): + self.server.shutdown() + + def test_sending_dict(self): + data = {"foo": "bar"} + response = self.send_to_client(data) + self.assertEqual(response, data) + + def test_sending_array(self): + data = ["foo", "bar"] + response = self.send_to_client(data) + self.assertEqual(response, data) + + def test_sending_full_buffer(self): + data = ["foo"] * 4096 # Pad out the buffer + response = self.send_to_client(data) + self.assertEqual(response, data) + + def test_sending_empty_payload(self): + response = self.send_to_client("") + self.assertEqual(response, "") + + def test_multiple_send_to_same_server(self): + data = {"foo": "bar"} + for _ in range(10): + response = self.send_to_client(data) + self.assertEqual(response, data)