diff --git a/ipc_unix/simple.py b/ipc_unix/simple.py index 5a4294a..5aad761 100644 --- a/ipc_unix/simple.py +++ b/ipc_unix/simple.py @@ -7,11 +7,15 @@ import ujson from ipc_unix.utils import read_payload -def send_to(socket_path, data: dict): - with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: - sock.connect(socket_path) - sock.sendall(ujson.dumps(data).encode() + b"\n") - return read_payload(sock)[0] +class Client: + def __init__(self, socket_path): + self.socket_path = socket_path + + def send(self, data: dict): + with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: + sock.connect(self.socket_path) + sock.sendall(ujson.dumps(data).encode() + b"\n") + return read_payload(sock)[0] class RequestHandler(socketserver.BaseRequestHandler): diff --git a/tests/test_simple.py b/tests/test_simple.py index e256cae..25000cb 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -1,7 +1,6 @@ -from functools import partial from unittest import TestCase -from ipc_unix.simple import send_to +from ipc_unix.simple import Client from ipc_unix.utils import BUFFER_SIZE from tests import EchoServer, get_temp_file_path @@ -11,23 +10,23 @@ class SimpleServerTestCase(TestCase): self.socket_path = get_temp_file_path() self.server = EchoServer(self.socket_path) self.server.serve_in_thread() - self.send_to_client = partial(send_to, self.socket_path) + self.client = Client(self.socket_path) def tearDown(self): self.server.close() def test_sending_dict(self): data = {"foo": "bar"} - response = self.send_to_client(data) + response = self.client.send(data) self.assertEqual(response, data) def test_sending_full_buffer(self): data = {"foo" + str(i): i for i in range(BUFFER_SIZE)} - response = self.send_to_client(data) + response = self.client.send(data) self.assertEqual(response, data) def test_multiple_send_to_same_server(self): data = {"foo": "bar"} for _ in range(10): - response = self.send_to_client(data) + response = self.client.send(data) self.assertEqual(response, data)