Add pub sub server
This commit is contained in:
parent
f6c37a64a4
commit
aba697bfb2
5 changed files with 100 additions and 6 deletions
66
ipc_unix/pubsub.py
Normal file
66
ipc_unix/pubsub.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
import select
|
||||||
|
import socket
|
||||||
|
|
||||||
|
import ujson
|
||||||
|
from ipc_unix.utils import read_payload
|
||||||
|
|
||||||
|
|
||||||
|
class Subscriber:
|
||||||
|
def __init__(self, socket_path):
|
||||||
|
self.socket_path = socket_path
|
||||||
|
self.socket = socket.socket(
|
||||||
|
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
|
||||||
|
)
|
||||||
|
self.socket.connect(self.socket_path)
|
||||||
|
|
||||||
|
def listen(self):
|
||||||
|
while True:
|
||||||
|
yield self.get_message()
|
||||||
|
|
||||||
|
def get_message(self):
|
||||||
|
return read_payload(self.socket)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.socket.close()
|
||||||
|
|
||||||
|
|
||||||
|
class Publisher:
|
||||||
|
def __init__(self, socket_path):
|
||||||
|
self.socket_path = socket_path
|
||||||
|
self.master_socket = socket.socket(
|
||||||
|
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
|
||||||
|
)
|
||||||
|
self.master_socket.bind(self.socket_path)
|
||||||
|
self.master_socket.listen()
|
||||||
|
self.connections = []
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.master_socket.close()
|
||||||
|
self.connections.clear()
|
||||||
|
|
||||||
|
def accept_new_connection(self):
|
||||||
|
readable, _, _ = select.select([self.master_socket], [], [], 1)
|
||||||
|
|
||||||
|
if self.master_socket in readable:
|
||||||
|
new_socket, _ = self.master_socket.accept()
|
||||||
|
self.connections.append(new_socket)
|
||||||
|
|
||||||
|
def write(self, message):
|
||||||
|
self.accept_new_connection()
|
||||||
|
|
||||||
|
_, writable, errorable = select.select([], self.connections, [], 1)
|
||||||
|
|
||||||
|
dead_sockets = []
|
||||||
|
|
||||||
|
if writable:
|
||||||
|
data = ujson.dumps(message).encode() + b"\n"
|
||||||
|
for sock in writable:
|
||||||
|
try:
|
||||||
|
sock.send(data)
|
||||||
|
except BrokenPipeError:
|
||||||
|
dead_sockets.append(sock)
|
||||||
|
|
||||||
|
for sock in dead_sockets:
|
||||||
|
if sock in self.connections:
|
||||||
|
self.connections.remove(sock)
|
||||||
|
sock.close()
|
|
@ -4,7 +4,7 @@ import ujson
|
||||||
def read_payload(payload):
|
def read_payload(payload):
|
||||||
data = b""
|
data = b""
|
||||||
while b"\n" not in data:
|
while b"\n" not in data:
|
||||||
message = payload.recv(4096)
|
message = payload.recv(2)
|
||||||
if message == b"":
|
if message == b"":
|
||||||
break
|
break
|
||||||
data += message
|
data += message
|
||||||
|
|
|
@ -8,7 +8,7 @@ class EchoServer(server.Server):
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
def get_random_path() -> str:
|
def get_temp_file_path() -> str:
|
||||||
_, temp_file_path = tempfile.mkstemp()
|
_, temp_file_path = tempfile.mkstemp()
|
||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
return temp_file_path
|
return temp_file_path
|
||||||
|
|
28
tests/test_pubsub.py
Normal file
28
tests/test_pubsub.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
from ipc_unix import pubsub
|
||||||
|
from tests import get_temp_file_path
|
||||||
|
|
||||||
|
|
||||||
|
class PubSubTestCase(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.socket_path = get_temp_file_path()
|
||||||
|
self.publisher = pubsub.Publisher(self.socket_path)
|
||||||
|
self.subscriber = pubsub.Subscriber(self.socket_path)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.publisher.close()
|
||||||
|
self.subscriber.close()
|
||||||
|
|
||||||
|
def test_transmits(self):
|
||||||
|
self.publisher.write({"foo": "bar"})
|
||||||
|
response = self.subscriber.get_message()
|
||||||
|
self.assertEqual(response, {"foo": "bar"})
|
||||||
|
|
||||||
|
def test_buffers_messages(self):
|
||||||
|
for i in range(5):
|
||||||
|
self.publisher.write(i)
|
||||||
|
messages = []
|
||||||
|
for i in range(5):
|
||||||
|
messages.append(self.subscriber.get_message())
|
||||||
|
self.assertEqual(messages, [0, 1, 2, 3, 4])
|
|
@ -2,12 +2,12 @@ from functools import partial
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from ipc_unix import client
|
from ipc_unix import client
|
||||||
from tests import EchoServer, get_random_path
|
from tests import EchoServer, get_temp_file_path
|
||||||
|
|
||||||
|
|
||||||
class BasicServerTestCase(TestCase):
|
class BasicServerTestCase(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.socket_path = get_random_path()
|
self.socket_path = get_temp_file_path()
|
||||||
self.server = EchoServer(self.socket_path)
|
self.server = EchoServer(self.socket_path)
|
||||||
self.server.serve_in_thread()
|
self.server.serve_in_thread()
|
||||||
self.send_to_client = partial(client.send_to, self.socket_path)
|
self.send_to_client = partial(client.send_to, self.socket_path)
|
||||||
|
@ -20,8 +20,8 @@ class BasicServerTestCase(TestCase):
|
||||||
response = self.send_to_client(data)
|
response = self.send_to_client(data)
|
||||||
self.assertEqual(response, data)
|
self.assertEqual(response, data)
|
||||||
|
|
||||||
def test_sending_array(self):
|
def test_sending_string(self):
|
||||||
data = ["foo", "bar"]
|
data = "foo"
|
||||||
response = self.send_to_client(data)
|
response = self.send_to_client(data)
|
||||||
self.assertEqual(response, data)
|
self.assertEqual(response, data)
|
||||||
|
|
||||||
|
|
Reference in a new issue