Force payload to be dict, increase stability and speed

This commit is contained in:
Jake Howard 2018-12-08 01:12:34 +00:00
parent 1b69e09276
commit 07affb4c36
Signed by: jake
GPG key ID: 57AFB45680EDD477
6 changed files with 59 additions and 38 deletions

View file

@ -2,24 +2,34 @@ import select
import socket
import ujson
from ipc_unix.utils import read_payload
from ipc_unix.utils import DEFAULT_SOCKET_TIMEOUT, read_payload, socket_has_data
class Subscriber:
def __init__(self, socket_path):
self.socket_path = socket_path
self.socket = socket.socket(
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
)
self.socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
self.socket.connect(self.socket_path)
@property
def has_data(self):
return socket_has_data(self.socket)
def listen(self):
while True:
yield self.get_message()
yield from self.get_message()
def get_message(self):
def get_messages(self) -> dict:
return read_payload(self.socket)
def flush_data(self):
while self.has_data:
yield from self.get_messages()
def get_latest_message(self):
data = list(self.flush_data())
return data[-1] if data else None
def close(self):
self.socket.close()
@ -27,11 +37,9 @@ class Subscriber:
class Publisher:
def __init__(self, socket_path):
self.socket_path = socket_path
self.master_socket = socket.socket(
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
)
self.master_socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
self.master_socket.bind(self.socket_path)
self.master_socket.listen()
self.master_socket.listen(1)
self.connections = []
def close(self):
@ -39,16 +47,16 @@ class Publisher:
self.connections.clear()
def accept_new_connection(self):
readable, _, _ = select.select([self.master_socket], [], [], 1)
if self.master_socket in readable:
if socket_has_data(self.master_socket):
new_socket, _ = self.master_socket.accept()
self.connections.append(new_socket)
def write(self, message):
def write(self, message: dict):
self.accept_new_connection()
_, writable, errorable = select.select([], self.connections, [], 1)
_, writable, errorable = select.select(
[], self.connections, [], DEFAULT_SOCKET_TIMEOUT
)
dead_sockets = []

View file

@ -6,19 +6,19 @@ import ujson
from ipc_unix.utils import read_payload
def send_to(socket_path, data):
def send_to(socket_path, data: dict):
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
sock.connect(socket_path)
sock.sendall(ujson.dumps(data).encode() + b"\n")
return read_payload(sock)
return read_payload(sock)[0]
class RequestHandler(socketserver.BaseRequestHandler):
def handle_request(self, request):
def handle_request(self, request: dict):
raise NotImplementedError("Failed to override `handle_request`")
def handle(self):
data = read_payload(self.request)
data = read_payload(self.request)[0]
response = self.handle_request(data)
self.request.sendall(ujson.dumps(response).encode())
@ -45,5 +45,5 @@ class Server:
self.shutdown()
self.server.server_close()
def handle_request(self, request):
def handle_request(self, request: dict):
raise NotImplementedError("Must override `handle_request`")

View file

@ -1,11 +1,28 @@
import select
import ujson
BUFFER_SIZE = 4096
DEFAULT_SOCKET_TIMEOUT = 0.1
def socket_has_data(sock, timeout=DEFAULT_SOCKET_TIMEOUT) -> bool:
readable, _, _ = select.select([sock], [], [], timeout)
return sock in readable
def read_payload(payload):
data = b""
while b"\n" not in data:
message = payload.recv(2)
if not socket_has_data(payload):
break
message = payload.recv(BUFFER_SIZE)
if message == b"":
break
data += message
return ujson.loads(data)
parsed_data = []
for row in data.split(b"\n"):
if not row.strip():
continue
parsed_data.append(ujson.loads(row))
return parsed_data

View file

@ -3,3 +3,4 @@ multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88

View file

@ -16,13 +16,16 @@ class PubSubTestCase(TestCase):
def test_transmits(self):
self.publisher.write({"foo": "bar"})
response = self.subscriber.get_message()
response = self.subscriber.get_latest_message()
self.assertEqual(response, {"foo": "bar"})
def test_no_messages(self):
self.assertIsNone(self.subscriber.get_latest_message())
self.assertFalse(self.subscriber.has_data)
def test_buffers_messages(self):
for i in range(5):
self.publisher.write(i)
messages = []
for i in range(5):
messages.append(self.subscriber.get_message())
self.assertEqual(messages, [0, 1, 2, 3, 4])
self.publisher.write({"data": i})
all_messages = self.subscriber.flush_data()
message_ids = [message["data"] for message in all_messages]
self.assertEqual(message_ids, [0, 1, 2, 3, 4])

View file

@ -2,6 +2,7 @@ from functools import partial
from unittest import TestCase
from ipc_unix.simple import send_to
from ipc_unix.utils import BUFFER_SIZE
from tests import EchoServer, get_temp_file_path
@ -20,20 +21,11 @@ class SimpleServerTestCase(TestCase):
response = self.send_to_client(data)
self.assertEqual(response, data)
def test_sending_string(self):
data = "foo"
response = self.send_to_client(data)
self.assertEqual(response, data)
def test_sending_full_buffer(self):
data = ["foo"] * 4096 # Pad out the buffer
data = {"foo" + str(i): i for i in range(BUFFER_SIZE)}
response = self.send_to_client(data)
self.assertEqual(response, data)
def test_sending_empty_payload(self):
response = self.send_to_client("")
self.assertEqual(response, "")
def test_multiple_send_to_same_server(self):
data = {"foo": "bar"}
for _ in range(10):