diff --git a/ipc_unix/pubsub.py b/ipc_unix/pubsub.py new file mode 100644 index 0000000..068dcc7 --- /dev/null +++ b/ipc_unix/pubsub.py @@ -0,0 +1,66 @@ +import select +import socket + +import ujson +from ipc_unix.utils import read_payload + + +class Subscriber: + def __init__(self, socket_path): + self.socket_path = socket_path + self.socket = socket.socket( + socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK + ) + self.socket.connect(self.socket_path) + + def listen(self): + while True: + yield self.get_message() + + def get_message(self): + return read_payload(self.socket) + + def close(self): + self.socket.close() + + +class Publisher: + def __init__(self, socket_path): + self.socket_path = socket_path + self.master_socket = socket.socket( + socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK + ) + self.master_socket.bind(self.socket_path) + self.master_socket.listen() + self.connections = [] + + def close(self): + self.master_socket.close() + self.connections.clear() + + def accept_new_connection(self): + readable, _, _ = select.select([self.master_socket], [], [], 1) + + if self.master_socket in readable: + new_socket, _ = self.master_socket.accept() + self.connections.append(new_socket) + + def write(self, message): + self.accept_new_connection() + + _, writable, errorable = select.select([], self.connections, [], 1) + + dead_sockets = [] + + if writable: + data = ujson.dumps(message).encode() + b"\n" + for sock in writable: + try: + sock.send(data) + except BrokenPipeError: + dead_sockets.append(sock) + + for sock in dead_sockets: + if sock in self.connections: + self.connections.remove(sock) + sock.close() diff --git a/ipc_unix/utils.py b/ipc_unix/utils.py index 6ba2124..04d5234 100644 --- a/ipc_unix/utils.py +++ b/ipc_unix/utils.py @@ -4,7 +4,7 @@ import ujson def read_payload(payload): data = b"" while b"\n" not in data: - message = payload.recv(4096) + message = payload.recv(2) if message == b"": break data += message diff --git a/tests/__init__.py b/tests/__init__.py index d2d5d38..fa15af4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -8,7 +8,7 @@ class EchoServer(server.Server): return request -def get_random_path() -> str: +def get_temp_file_path() -> str: _, temp_file_path = tempfile.mkstemp() os.remove(temp_file_path) return temp_file_path diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py new file mode 100644 index 0000000..94a757a --- /dev/null +++ b/tests/test_pubsub.py @@ -0,0 +1,28 @@ +from unittest import TestCase + +from ipc_unix import pubsub +from tests import get_temp_file_path + + +class PubSubTestCase(TestCase): + def setUp(self): + self.socket_path = get_temp_file_path() + self.publisher = pubsub.Publisher(self.socket_path) + self.subscriber = pubsub.Subscriber(self.socket_path) + + def tearDown(self): + self.publisher.close() + self.subscriber.close() + + def test_transmits(self): + self.publisher.write({"foo": "bar"}) + response = self.subscriber.get_message() + self.assertEqual(response, {"foo": "bar"}) + + def test_buffers_messages(self): + for i in range(5): + self.publisher.write(i) + messages = [] + for i in range(5): + messages.append(self.subscriber.get_message()) + self.assertEqual(messages, [0, 1, 2, 3, 4]) diff --git a/tests/test_server.py b/tests/test_server.py index d7054fa..53c59a6 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,12 +2,12 @@ from functools import partial from unittest import TestCase from ipc_unix import client -from tests import EchoServer, get_random_path +from tests import EchoServer, get_temp_file_path class BasicServerTestCase(TestCase): def setUp(self): - self.socket_path = get_random_path() + self.socket_path = get_temp_file_path() self.server = EchoServer(self.socket_path) self.server.serve_in_thread() self.send_to_client = partial(client.send_to, self.socket_path) @@ -20,8 +20,8 @@ class BasicServerTestCase(TestCase): response = self.send_to_client(data) self.assertEqual(response, data) - def test_sending_array(self): - data = ["foo", "bar"] + def test_sending_string(self): + data = "foo" response = self.send_to_client(data) self.assertEqual(response, data)