diff --git a/ipc_unix/pubsub.py b/ipc_unix/pubsub.py index 068dcc7..73fb0a7 100644 --- a/ipc_unix/pubsub.py +++ b/ipc_unix/pubsub.py @@ -2,24 +2,34 @@ import select import socket import ujson -from ipc_unix.utils import read_payload +from ipc_unix.utils import DEFAULT_SOCKET_TIMEOUT, read_payload, socket_has_data class Subscriber: def __init__(self, socket_path): self.socket_path = socket_path - self.socket = socket.socket( - socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK - ) + self.socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) self.socket.connect(self.socket_path) + @property + def has_data(self): + return socket_has_data(self.socket) + def listen(self): while True: - yield self.get_message() + yield from self.get_message() - def get_message(self): + def get_messages(self) -> dict: return read_payload(self.socket) + def flush_data(self): + while self.has_data: + yield from self.get_messages() + + def get_latest_message(self): + data = list(self.flush_data()) + return data[-1] if data else None + def close(self): self.socket.close() @@ -27,11 +37,9 @@ class Subscriber: class Publisher: def __init__(self, socket_path): self.socket_path = socket_path - self.master_socket = socket.socket( - socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK - ) + self.master_socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) self.master_socket.bind(self.socket_path) - self.master_socket.listen() + self.master_socket.listen(1) self.connections = [] def close(self): @@ -39,16 +47,16 @@ class Publisher: self.connections.clear() def accept_new_connection(self): - readable, _, _ = select.select([self.master_socket], [], [], 1) - - if self.master_socket in readable: + if socket_has_data(self.master_socket): new_socket, _ = self.master_socket.accept() self.connections.append(new_socket) - def write(self, message): + def write(self, message: dict): self.accept_new_connection() - _, writable, errorable = select.select([], self.connections, [], 1) + _, writable, errorable = select.select( + [], self.connections, [], DEFAULT_SOCKET_TIMEOUT + ) dead_sockets = [] diff --git a/ipc_unix/simple.py b/ipc_unix/simple.py index 66afe9f..cec737f 100644 --- a/ipc_unix/simple.py +++ b/ipc_unix/simple.py @@ -6,19 +6,19 @@ import ujson from ipc_unix.utils import read_payload -def send_to(socket_path, data): +def send_to(socket_path, data: dict): with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: sock.connect(socket_path) sock.sendall(ujson.dumps(data).encode() + b"\n") - return read_payload(sock) + return read_payload(sock)[0] class RequestHandler(socketserver.BaseRequestHandler): - def handle_request(self, request): + def handle_request(self, request: dict): raise NotImplementedError("Failed to override `handle_request`") def handle(self): - data = read_payload(self.request) + data = read_payload(self.request)[0] response = self.handle_request(data) self.request.sendall(ujson.dumps(response).encode()) @@ -45,5 +45,5 @@ class Server: self.shutdown() self.server.server_close() - def handle_request(self, request): + def handle_request(self, request: dict): raise NotImplementedError("Must override `handle_request`") diff --git a/ipc_unix/utils.py b/ipc_unix/utils.py index 04d5234..10b0254 100644 --- a/ipc_unix/utils.py +++ b/ipc_unix/utils.py @@ -1,11 +1,28 @@ +import select + import ujson +BUFFER_SIZE = 4096 +DEFAULT_SOCKET_TIMEOUT = 0.1 + + +def socket_has_data(sock, timeout=DEFAULT_SOCKET_TIMEOUT) -> bool: + readable, _, _ = select.select([sock], [], [], timeout) + return sock in readable + def read_payload(payload): data = b"" while b"\n" not in data: - message = payload.recv(2) + if not socket_has_data(payload): + break + message = payload.recv(BUFFER_SIZE) if message == b"": break data += message - return ujson.loads(data) + parsed_data = [] + for row in data.split(b"\n"): + if not row.strip(): + continue + parsed_data.append(ujson.loads(row)) + return parsed_data diff --git a/setup.cfg b/setup.cfg index f50d665..9eea40a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,3 +3,4 @@ multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 use_parentheses=True +line_length=88 diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 94a757a..be3b83e 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -16,13 +16,16 @@ class PubSubTestCase(TestCase): def test_transmits(self): self.publisher.write({"foo": "bar"}) - response = self.subscriber.get_message() + response = self.subscriber.get_latest_message() self.assertEqual(response, {"foo": "bar"}) + def test_no_messages(self): + self.assertIsNone(self.subscriber.get_latest_message()) + self.assertFalse(self.subscriber.has_data) + def test_buffers_messages(self): for i in range(5): - self.publisher.write(i) - messages = [] - for i in range(5): - messages.append(self.subscriber.get_message()) - self.assertEqual(messages, [0, 1, 2, 3, 4]) + self.publisher.write({"data": i}) + all_messages = self.subscriber.flush_data() + message_ids = [message["data"] for message in all_messages] + self.assertEqual(message_ids, [0, 1, 2, 3, 4]) diff --git a/tests/test_simple.py b/tests/test_simple.py index 740fcb3..5994eda 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -2,6 +2,7 @@ from functools import partial from unittest import TestCase from ipc_unix.simple import send_to +from ipc_unix.utils import BUFFER_SIZE from tests import EchoServer, get_temp_file_path @@ -20,20 +21,11 @@ class SimpleServerTestCase(TestCase): response = self.send_to_client(data) self.assertEqual(response, data) - def test_sending_string(self): - data = "foo" - response = self.send_to_client(data) - self.assertEqual(response, data) - def test_sending_full_buffer(self): - data = ["foo"] * 4096 # Pad out the buffer + data = {"foo" + str(i): i for i in range(BUFFER_SIZE)} response = self.send_to_client(data) self.assertEqual(response, data) - def test_sending_empty_payload(self): - response = self.send_to_client("") - self.assertEqual(response, "") - def test_multiple_send_to_same_server(self): data = {"foo": "bar"} for _ in range(10):