Use background thread to accept new connections
This commit is contained in:
parent
886882a079
commit
abd20a6b4b
2 changed files with 43 additions and 3 deletions
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
|
import threading
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
from ipc_unix.utils import read_payload, socket_has_data
|
from ipc_unix.utils import read_payload, socket_has_data
|
||||||
|
@ -46,22 +47,42 @@ class Publisher:
|
||||||
self.master_socket.bind(self.socket_path)
|
self.master_socket.bind(self.socket_path)
|
||||||
self.master_socket.listen()
|
self.master_socket.listen()
|
||||||
self.connections = []
|
self.connections = []
|
||||||
|
self.accepting_new_connections = threading.Event()
|
||||||
|
self.accepting_new_connections.set()
|
||||||
|
self.new_connections_thread = threading.Thread(
|
||||||
|
target=self._accept_new_connections
|
||||||
|
)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.accepting_new_connections.set()
|
||||||
|
self.new_connections_thread.start()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
self.accepting_new_connections.clear()
|
||||||
|
if self.new_connections_thread.is_alive():
|
||||||
|
self.new_connections_thread.join()
|
||||||
self.master_socket.close()
|
self.master_socket.close()
|
||||||
for connection in self.connections:
|
for connection in self.connections:
|
||||||
connection.close()
|
connection.close()
|
||||||
self.connections.clear()
|
self.connections.clear()
|
||||||
os.remove(self.socket_path)
|
os.remove(self.socket_path)
|
||||||
|
|
||||||
def accept_new_connection(self):
|
def accept_outstanding_connections(self):
|
||||||
|
if self.new_connections_thread.is_alive():
|
||||||
|
raise Exception(
|
||||||
|
"Cannot accept connections manually whilst thread is running"
|
||||||
|
)
|
||||||
while socket_has_data(self.master_socket):
|
while socket_has_data(self.master_socket):
|
||||||
new_socket, _ = self.master_socket.accept()
|
new_socket, _ = self.master_socket.accept()
|
||||||
self.connections.append(new_socket)
|
self.connections.append(new_socket)
|
||||||
|
|
||||||
def write(self, message: dict):
|
def _accept_new_connections(self):
|
||||||
self.accept_new_connection()
|
while self.accepting_new_connections.is_set():
|
||||||
|
if socket_has_data(self.master_socket):
|
||||||
|
new_socket, _ = self.master_socket.accept()
|
||||||
|
self.connections.append(new_socket)
|
||||||
|
|
||||||
|
def write(self, message: dict):
|
||||||
_, writable, errorable = select.select([], self.connections, [], 1)
|
_, writable, errorable = select.select([], self.connections, [], 1)
|
||||||
|
|
||||||
dead_sockets = []
|
dead_sockets = []
|
||||||
|
|
|
@ -9,6 +9,7 @@ class PubSubTestCase(TestCase):
|
||||||
self.socket_path = get_temp_file_path()
|
self.socket_path = get_temp_file_path()
|
||||||
self.publisher = pubsub.Publisher(self.socket_path)
|
self.publisher = pubsub.Publisher(self.socket_path)
|
||||||
self.subscriber = pubsub.Subscriber(self.socket_path)
|
self.subscriber = pubsub.Subscriber(self.socket_path)
|
||||||
|
self.publisher.accept_outstanding_connections()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.publisher.close()
|
self.publisher.close()
|
||||||
|
@ -39,6 +40,7 @@ class PubSubTestCase(TestCase):
|
||||||
|
|
||||||
def test_multiple_subscribers(self):
|
def test_multiple_subscribers(self):
|
||||||
subscriber_2 = pubsub.Subscriber(self.socket_path)
|
subscriber_2 = pubsub.Subscriber(self.socket_path)
|
||||||
|
self.publisher.accept_outstanding_connections()
|
||||||
self.publisher.write({"foo": "bar"})
|
self.publisher.write({"foo": "bar"})
|
||||||
self.assertEqual(self.subscriber.get_latest_message(), {"foo": "bar"})
|
self.assertEqual(self.subscriber.get_latest_message(), {"foo": "bar"})
|
||||||
self.assertEqual(subscriber_2.get_latest_message(), {"foo": "bar"})
|
self.assertEqual(subscriber_2.get_latest_message(), {"foo": "bar"})
|
||||||
|
@ -47,6 +49,7 @@ class PubSubTestCase(TestCase):
|
||||||
subscribers = []
|
subscribers = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
subscribers.append(pubsub.Subscriber(self.socket_path))
|
subscribers.append(pubsub.Subscriber(self.socket_path))
|
||||||
|
self.publisher.accept_outstanding_connections()
|
||||||
self.publisher.write({"foo": "bar"})
|
self.publisher.write({"foo": "bar"})
|
||||||
for subscriber in subscribers:
|
for subscriber in subscribers:
|
||||||
self.assertEqual(subscriber.get_latest_message(), {"foo": "bar"})
|
self.assertEqual(subscriber.get_latest_message(), {"foo": "bar"})
|
||||||
|
@ -55,3 +58,19 @@ class PubSubTestCase(TestCase):
|
||||||
def test_no_subscribers(self):
|
def test_no_subscribers(self):
|
||||||
self.subscriber.close()
|
self.subscriber.close()
|
||||||
self.publisher.write({"foo": "bar"})
|
self.publisher.write({"foo": "bar"})
|
||||||
|
|
||||||
|
def test_cant_accept_connections_with_thread_running(self):
|
||||||
|
self.publisher.start()
|
||||||
|
with self.assertRaises(Exception) as e:
|
||||||
|
self.publisher.accept_outstanding_connections()
|
||||||
|
self.assertIn(
|
||||||
|
"Cannot accept connections manually whilst thread is running",
|
||||||
|
str(e.exception),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_accepts_connections(self):
|
||||||
|
self.assertEqual(len(self.publisher.connections), 1)
|
||||||
|
pubsub.Subscriber(self.socket_path)
|
||||||
|
self.assertEqual(len(self.publisher.connections), 1)
|
||||||
|
self.publisher.accept_outstanding_connections()
|
||||||
|
self.assertEqual(len(self.publisher.connections), 2)
|
||||||
|
|
Reference in a new issue