Force payload to be dict, increase stability and speed
This commit is contained in:
parent
1b69e09276
commit
07affb4c36
6 changed files with 59 additions and 38 deletions
|
@ -2,24 +2,34 @@ import select
|
|||
import socket
|
||||
|
||||
import ujson
|
||||
from ipc_unix.utils import read_payload
|
||||
from ipc_unix.utils import DEFAULT_SOCKET_TIMEOUT, read_payload, socket_has_data
|
||||
|
||||
|
||||
class Subscriber:
|
||||
def __init__(self, socket_path):
|
||||
self.socket_path = socket_path
|
||||
self.socket = socket.socket(
|
||||
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
|
||||
)
|
||||
self.socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
|
||||
self.socket.connect(self.socket_path)
|
||||
|
||||
@property
|
||||
def has_data(self):
|
||||
return socket_has_data(self.socket)
|
||||
|
||||
def listen(self):
|
||||
while True:
|
||||
yield self.get_message()
|
||||
yield from self.get_message()
|
||||
|
||||
def get_message(self):
|
||||
def get_messages(self) -> dict:
|
||||
return read_payload(self.socket)
|
||||
|
||||
def flush_data(self):
|
||||
while self.has_data:
|
||||
yield from self.get_messages()
|
||||
|
||||
def get_latest_message(self):
|
||||
data = list(self.flush_data())
|
||||
return data[-1] if data else None
|
||||
|
||||
def close(self):
|
||||
self.socket.close()
|
||||
|
||||
|
@ -27,11 +37,9 @@ class Subscriber:
|
|||
class Publisher:
|
||||
def __init__(self, socket_path):
|
||||
self.socket_path = socket_path
|
||||
self.master_socket = socket.socket(
|
||||
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
|
||||
)
|
||||
self.master_socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
|
||||
self.master_socket.bind(self.socket_path)
|
||||
self.master_socket.listen()
|
||||
self.master_socket.listen(1)
|
||||
self.connections = []
|
||||
|
||||
def close(self):
|
||||
|
@ -39,16 +47,16 @@ class Publisher:
|
|||
self.connections.clear()
|
||||
|
||||
def accept_new_connection(self):
|
||||
readable, _, _ = select.select([self.master_socket], [], [], 1)
|
||||
|
||||
if self.master_socket in readable:
|
||||
if socket_has_data(self.master_socket):
|
||||
new_socket, _ = self.master_socket.accept()
|
||||
self.connections.append(new_socket)
|
||||
|
||||
def write(self, message):
|
||||
def write(self, message: dict):
|
||||
self.accept_new_connection()
|
||||
|
||||
_, writable, errorable = select.select([], self.connections, [], 1)
|
||||
_, writable, errorable = select.select(
|
||||
[], self.connections, [], DEFAULT_SOCKET_TIMEOUT
|
||||
)
|
||||
|
||||
dead_sockets = []
|
||||
|
||||
|
|
|
@ -6,19 +6,19 @@ import ujson
|
|||
from ipc_unix.utils import read_payload
|
||||
|
||||
|
||||
def send_to(socket_path, data):
|
||||
def send_to(socket_path, data: dict):
|
||||
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
|
||||
sock.connect(socket_path)
|
||||
sock.sendall(ujson.dumps(data).encode() + b"\n")
|
||||
return read_payload(sock)
|
||||
return read_payload(sock)[0]
|
||||
|
||||
|
||||
class RequestHandler(socketserver.BaseRequestHandler):
|
||||
def handle_request(self, request):
|
||||
def handle_request(self, request: dict):
|
||||
raise NotImplementedError("Failed to override `handle_request`")
|
||||
|
||||
def handle(self):
|
||||
data = read_payload(self.request)
|
||||
data = read_payload(self.request)[0]
|
||||
response = self.handle_request(data)
|
||||
self.request.sendall(ujson.dumps(response).encode())
|
||||
|
||||
|
@ -45,5 +45,5 @@ class Server:
|
|||
self.shutdown()
|
||||
self.server.server_close()
|
||||
|
||||
def handle_request(self, request):
|
||||
def handle_request(self, request: dict):
|
||||
raise NotImplementedError("Must override `handle_request`")
|
||||
|
|
|
@ -1,11 +1,28 @@
|
|||
import select
|
||||
|
||||
import ujson
|
||||
|
||||
BUFFER_SIZE = 4096
|
||||
DEFAULT_SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
def socket_has_data(sock, timeout=DEFAULT_SOCKET_TIMEOUT) -> bool:
|
||||
readable, _, _ = select.select([sock], [], [], timeout)
|
||||
return sock in readable
|
||||
|
||||
|
||||
def read_payload(payload):
|
||||
data = b""
|
||||
while b"\n" not in data:
|
||||
message = payload.recv(2)
|
||||
if not socket_has_data(payload):
|
||||
break
|
||||
message = payload.recv(BUFFER_SIZE)
|
||||
if message == b"":
|
||||
break
|
||||
data += message
|
||||
return ujson.loads(data)
|
||||
parsed_data = []
|
||||
for row in data.split(b"\n"):
|
||||
if not row.strip():
|
||||
continue
|
||||
parsed_data.append(ujson.loads(row))
|
||||
return parsed_data
|
||||
|
|
|
@ -3,3 +3,4 @@ multi_line_output=3
|
|||
include_trailing_comma=True
|
||||
force_grid_wrap=0
|
||||
use_parentheses=True
|
||||
line_length=88
|
||||
|
|
|
@ -16,13 +16,16 @@ class PubSubTestCase(TestCase):
|
|||
|
||||
def test_transmits(self):
|
||||
self.publisher.write({"foo": "bar"})
|
||||
response = self.subscriber.get_message()
|
||||
response = self.subscriber.get_latest_message()
|
||||
self.assertEqual(response, {"foo": "bar"})
|
||||
|
||||
def test_no_messages(self):
|
||||
self.assertIsNone(self.subscriber.get_latest_message())
|
||||
self.assertFalse(self.subscriber.has_data)
|
||||
|
||||
def test_buffers_messages(self):
|
||||
for i in range(5):
|
||||
self.publisher.write(i)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
messages.append(self.subscriber.get_message())
|
||||
self.assertEqual(messages, [0, 1, 2, 3, 4])
|
||||
self.publisher.write({"data": i})
|
||||
all_messages = self.subscriber.flush_data()
|
||||
message_ids = [message["data"] for message in all_messages]
|
||||
self.assertEqual(message_ids, [0, 1, 2, 3, 4])
|
||||
|
|
|
@ -2,6 +2,7 @@ from functools import partial
|
|||
from unittest import TestCase
|
||||
|
||||
from ipc_unix.simple import send_to
|
||||
from ipc_unix.utils import BUFFER_SIZE
|
||||
from tests import EchoServer, get_temp_file_path
|
||||
|
||||
|
||||
|
@ -20,20 +21,11 @@ class SimpleServerTestCase(TestCase):
|
|||
response = self.send_to_client(data)
|
||||
self.assertEqual(response, data)
|
||||
|
||||
def test_sending_string(self):
|
||||
data = "foo"
|
||||
response = self.send_to_client(data)
|
||||
self.assertEqual(response, data)
|
||||
|
||||
def test_sending_full_buffer(self):
|
||||
data = ["foo"] * 4096 # Pad out the buffer
|
||||
data = {"foo" + str(i): i for i in range(BUFFER_SIZE)}
|
||||
response = self.send_to_client(data)
|
||||
self.assertEqual(response, data)
|
||||
|
||||
def test_sending_empty_payload(self):
|
||||
response = self.send_to_client("")
|
||||
self.assertEqual(response, "")
|
||||
|
||||
def test_multiple_send_to_same_server(self):
|
||||
data = {"foo": "bar"}
|
||||
for _ in range(10):
|
||||
|
|
Reference in a new issue