Force payload to be dict, increase stability and speed
This commit is contained in:
parent
1b69e09276
commit
07affb4c36
6 changed files with 59 additions and 38 deletions
|
@ -2,24 +2,34 @@ import select
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
from ipc_unix.utils import read_payload
|
from ipc_unix.utils import DEFAULT_SOCKET_TIMEOUT, read_payload, socket_has_data
|
||||||
|
|
||||||
|
|
||||||
class Subscriber:
|
class Subscriber:
|
||||||
def __init__(self, socket_path):
|
def __init__(self, socket_path):
|
||||||
self.socket_path = socket_path
|
self.socket_path = socket_path
|
||||||
self.socket = socket.socket(
|
self.socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
|
||||||
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
|
|
||||||
)
|
|
||||||
self.socket.connect(self.socket_path)
|
self.socket.connect(self.socket_path)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_data(self):
|
||||||
|
return socket_has_data(self.socket)
|
||||||
|
|
||||||
def listen(self):
|
def listen(self):
|
||||||
while True:
|
while True:
|
||||||
yield self.get_message()
|
yield from self.get_message()
|
||||||
|
|
||||||
def get_message(self):
|
def get_messages(self) -> dict:
|
||||||
return read_payload(self.socket)
|
return read_payload(self.socket)
|
||||||
|
|
||||||
|
def flush_data(self):
|
||||||
|
while self.has_data:
|
||||||
|
yield from self.get_messages()
|
||||||
|
|
||||||
|
def get_latest_message(self):
|
||||||
|
data = list(self.flush_data())
|
||||||
|
return data[-1] if data else None
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
|
||||||
|
@ -27,11 +37,9 @@ class Subscriber:
|
||||||
class Publisher:
|
class Publisher:
|
||||||
def __init__(self, socket_path):
|
def __init__(self, socket_path):
|
||||||
self.socket_path = socket_path
|
self.socket_path = socket_path
|
||||||
self.master_socket = socket.socket(
|
self.master_socket = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM)
|
||||||
socket.AF_UNIX, type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK
|
|
||||||
)
|
|
||||||
self.master_socket.bind(self.socket_path)
|
self.master_socket.bind(self.socket_path)
|
||||||
self.master_socket.listen()
|
self.master_socket.listen(1)
|
||||||
self.connections = []
|
self.connections = []
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -39,16 +47,16 @@ class Publisher:
|
||||||
self.connections.clear()
|
self.connections.clear()
|
||||||
|
|
||||||
def accept_new_connection(self):
|
def accept_new_connection(self):
|
||||||
readable, _, _ = select.select([self.master_socket], [], [], 1)
|
if socket_has_data(self.master_socket):
|
||||||
|
|
||||||
if self.master_socket in readable:
|
|
||||||
new_socket, _ = self.master_socket.accept()
|
new_socket, _ = self.master_socket.accept()
|
||||||
self.connections.append(new_socket)
|
self.connections.append(new_socket)
|
||||||
|
|
||||||
def write(self, message):
|
def write(self, message: dict):
|
||||||
self.accept_new_connection()
|
self.accept_new_connection()
|
||||||
|
|
||||||
_, writable, errorable = select.select([], self.connections, [], 1)
|
_, writable, errorable = select.select(
|
||||||
|
[], self.connections, [], DEFAULT_SOCKET_TIMEOUT
|
||||||
|
)
|
||||||
|
|
||||||
dead_sockets = []
|
dead_sockets = []
|
||||||
|
|
||||||
|
|
|
@ -6,19 +6,19 @@ import ujson
|
||||||
from ipc_unix.utils import read_payload
|
from ipc_unix.utils import read_payload
|
||||||
|
|
||||||
|
|
||||||
def send_to(socket_path, data):
|
def send_to(socket_path, data: dict):
|
||||||
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
|
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
|
||||||
sock.connect(socket_path)
|
sock.connect(socket_path)
|
||||||
sock.sendall(ujson.dumps(data).encode() + b"\n")
|
sock.sendall(ujson.dumps(data).encode() + b"\n")
|
||||||
return read_payload(sock)
|
return read_payload(sock)[0]
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(socketserver.BaseRequestHandler):
|
class RequestHandler(socketserver.BaseRequestHandler):
|
||||||
def handle_request(self, request):
|
def handle_request(self, request: dict):
|
||||||
raise NotImplementedError("Failed to override `handle_request`")
|
raise NotImplementedError("Failed to override `handle_request`")
|
||||||
|
|
||||||
def handle(self):
|
def handle(self):
|
||||||
data = read_payload(self.request)
|
data = read_payload(self.request)[0]
|
||||||
response = self.handle_request(data)
|
response = self.handle_request(data)
|
||||||
self.request.sendall(ujson.dumps(response).encode())
|
self.request.sendall(ujson.dumps(response).encode())
|
||||||
|
|
||||||
|
@ -45,5 +45,5 @@ class Server:
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
self.server.server_close()
|
self.server.server_close()
|
||||||
|
|
||||||
def handle_request(self, request):
|
def handle_request(self, request: dict):
|
||||||
raise NotImplementedError("Must override `handle_request`")
|
raise NotImplementedError("Must override `handle_request`")
|
||||||
|
|
|
@ -1,11 +1,28 @@
|
||||||
|
import select
|
||||||
|
|
||||||
import ujson
|
import ujson
|
||||||
|
|
||||||
|
BUFFER_SIZE = 4096
|
||||||
|
DEFAULT_SOCKET_TIMEOUT = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def socket_has_data(sock, timeout=DEFAULT_SOCKET_TIMEOUT) -> bool:
|
||||||
|
readable, _, _ = select.select([sock], [], [], timeout)
|
||||||
|
return sock in readable
|
||||||
|
|
||||||
|
|
||||||
def read_payload(payload):
|
def read_payload(payload):
|
||||||
data = b""
|
data = b""
|
||||||
while b"\n" not in data:
|
while b"\n" not in data:
|
||||||
message = payload.recv(2)
|
if not socket_has_data(payload):
|
||||||
|
break
|
||||||
|
message = payload.recv(BUFFER_SIZE)
|
||||||
if message == b"":
|
if message == b"":
|
||||||
break
|
break
|
||||||
data += message
|
data += message
|
||||||
return ujson.loads(data)
|
parsed_data = []
|
||||||
|
for row in data.split(b"\n"):
|
||||||
|
if not row.strip():
|
||||||
|
continue
|
||||||
|
parsed_data.append(ujson.loads(row))
|
||||||
|
return parsed_data
|
||||||
|
|
|
@ -3,3 +3,4 @@ multi_line_output=3
|
||||||
include_trailing_comma=True
|
include_trailing_comma=True
|
||||||
force_grid_wrap=0
|
force_grid_wrap=0
|
||||||
use_parentheses=True
|
use_parentheses=True
|
||||||
|
line_length=88
|
||||||
|
|
|
@ -16,13 +16,16 @@ class PubSubTestCase(TestCase):
|
||||||
|
|
||||||
def test_transmits(self):
|
def test_transmits(self):
|
||||||
self.publisher.write({"foo": "bar"})
|
self.publisher.write({"foo": "bar"})
|
||||||
response = self.subscriber.get_message()
|
response = self.subscriber.get_latest_message()
|
||||||
self.assertEqual(response, {"foo": "bar"})
|
self.assertEqual(response, {"foo": "bar"})
|
||||||
|
|
||||||
|
def test_no_messages(self):
|
||||||
|
self.assertIsNone(self.subscriber.get_latest_message())
|
||||||
|
self.assertFalse(self.subscriber.has_data)
|
||||||
|
|
||||||
def test_buffers_messages(self):
|
def test_buffers_messages(self):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
self.publisher.write(i)
|
self.publisher.write({"data": i})
|
||||||
messages = []
|
all_messages = self.subscriber.flush_data()
|
||||||
for i in range(5):
|
message_ids = [message["data"] for message in all_messages]
|
||||||
messages.append(self.subscriber.get_message())
|
self.assertEqual(message_ids, [0, 1, 2, 3, 4])
|
||||||
self.assertEqual(messages, [0, 1, 2, 3, 4])
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ from functools import partial
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from ipc_unix.simple import send_to
|
from ipc_unix.simple import send_to
|
||||||
|
from ipc_unix.utils import BUFFER_SIZE
|
||||||
from tests import EchoServer, get_temp_file_path
|
from tests import EchoServer, get_temp_file_path
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,20 +21,11 @@ class SimpleServerTestCase(TestCase):
|
||||||
response = self.send_to_client(data)
|
response = self.send_to_client(data)
|
||||||
self.assertEqual(response, data)
|
self.assertEqual(response, data)
|
||||||
|
|
||||||
def test_sending_string(self):
|
|
||||||
data = "foo"
|
|
||||||
response = self.send_to_client(data)
|
|
||||||
self.assertEqual(response, data)
|
|
||||||
|
|
||||||
def test_sending_full_buffer(self):
|
def test_sending_full_buffer(self):
|
||||||
data = ["foo"] * 4096 # Pad out the buffer
|
data = {"foo" + str(i): i for i in range(BUFFER_SIZE)}
|
||||||
response = self.send_to_client(data)
|
response = self.send_to_client(data)
|
||||||
self.assertEqual(response, data)
|
self.assertEqual(response, data)
|
||||||
|
|
||||||
def test_sending_empty_payload(self):
|
|
||||||
response = self.send_to_client("")
|
|
||||||
self.assertEqual(response, "")
|
|
||||||
|
|
||||||
def test_multiple_send_to_same_server(self):
|
def test_multiple_send_to_same_server(self):
|
||||||
data = {"foo": "bar"}
|
data = {"foo": "bar"}
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
|
|
Reference in a new issue