diff --git a/ipc_unix/pubsub.py b/ipc_unix/pubsub.py index fd2c072..13696e9 100644 --- a/ipc_unix/pubsub.py +++ b/ipc_unix/pubsub.py @@ -1,6 +1,7 @@ import os import select import socket +import threading import ujson from ipc_unix.utils import read_payload, socket_has_data @@ -46,22 +47,42 @@ class Publisher: self.master_socket.bind(self.socket_path) self.master_socket.listen() self.connections = [] + self.accepting_new_connections = threading.Event() + self.accepting_new_connections.set() + self.new_connections_thread = threading.Thread( + target=self._accept_new_connections + ) + + def start(self): + self.accepting_new_connections.set() + self.new_connections_thread.start() def close(self): + self.accepting_new_connections.clear() + if self.new_connections_thread.is_alive(): + self.new_connections_thread.join() self.master_socket.close() for connection in self.connections: connection.close() self.connections.clear() os.remove(self.socket_path) - def accept_new_connection(self): + def accept_outstanding_connections(self): + if self.new_connections_thread.is_alive(): + raise Exception( + "Cannot accept connections manually whilst thread is running" + ) while socket_has_data(self.master_socket): new_socket, _ = self.master_socket.accept() self.connections.append(new_socket) - def write(self, message: dict): - self.accept_new_connection() + def _accept_new_connections(self): + while self.accepting_new_connections.is_set(): + if socket_has_data(self.master_socket): + new_socket, _ = self.master_socket.accept() + self.connections.append(new_socket) + def write(self, message: dict): _, writable, errorable = select.select([], self.connections, [], 1) dead_sockets = [] diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 38e099f..ae17579 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -9,6 +9,7 @@ class PubSubTestCase(TestCase): self.socket_path = get_temp_file_path() self.publisher = pubsub.Publisher(self.socket_path) self.subscriber = pubsub.Subscriber(self.socket_path) + self.publisher.accept_outstanding_connections() def tearDown(self): self.publisher.close() @@ -39,6 +40,7 @@ class PubSubTestCase(TestCase): def test_multiple_subscribers(self): subscriber_2 = pubsub.Subscriber(self.socket_path) + self.publisher.accept_outstanding_connections() self.publisher.write({"foo": "bar"}) self.assertEqual(self.subscriber.get_latest_message(), {"foo": "bar"}) self.assertEqual(subscriber_2.get_latest_message(), {"foo": "bar"}) @@ -47,6 +49,7 @@ class PubSubTestCase(TestCase): subscribers = [] for i in range(100): subscribers.append(pubsub.Subscriber(self.socket_path)) + self.publisher.accept_outstanding_connections() self.publisher.write({"foo": "bar"}) for subscriber in subscribers: self.assertEqual(subscriber.get_latest_message(), {"foo": "bar"}) @@ -55,3 +58,19 @@ class PubSubTestCase(TestCase): def test_no_subscribers(self): self.subscriber.close() self.publisher.write({"foo": "bar"}) + + def test_cant_accept_connections_with_thread_running(self): + self.publisher.start() + with self.assertRaises(Exception) as e: + self.publisher.accept_outstanding_connections() + self.assertIn( + "Cannot accept connections manually whilst thread is running", + str(e.exception), + ) + + def test_accepts_connections(self): + self.assertEqual(len(self.publisher.connections), 1) + pubsub.Subscriber(self.socket_path) + self.assertEqual(len(self.publisher.connections), 1) + self.publisher.accept_outstanding_connections() + self.assertEqual(len(self.publisher.connections), 2)