Refactor simple server client to class

This commit is contained in:
Jake Howard 2018-12-08 14:23:16 +00:00
parent abd20a6b4b
commit ead62d5671
Signed by: jake
GPG key ID: 57AFB45680EDD477
2 changed files with 14 additions and 11 deletions

View file

@ -7,11 +7,15 @@ import ujson
from ipc_unix.utils import read_payload from ipc_unix.utils import read_payload
def send_to(socket_path, data: dict): class Client:
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock: def __init__(self, socket_path):
sock.connect(socket_path) self.socket_path = socket_path
sock.sendall(ujson.dumps(data).encode() + b"\n")
return read_payload(sock)[0] def send(self, data: dict):
with socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
sock.connect(self.socket_path)
sock.sendall(ujson.dumps(data).encode() + b"\n")
return read_payload(sock)[0]
class RequestHandler(socketserver.BaseRequestHandler): class RequestHandler(socketserver.BaseRequestHandler):

View file

@ -1,7 +1,6 @@
from functools import partial
from unittest import TestCase from unittest import TestCase
from ipc_unix.simple import send_to from ipc_unix.simple import Client
from ipc_unix.utils import BUFFER_SIZE from ipc_unix.utils import BUFFER_SIZE
from tests import EchoServer, get_temp_file_path from tests import EchoServer, get_temp_file_path
@ -11,23 +10,23 @@ class SimpleServerTestCase(TestCase):
self.socket_path = get_temp_file_path() self.socket_path = get_temp_file_path()
self.server = EchoServer(self.socket_path) self.server = EchoServer(self.socket_path)
self.server.serve_in_thread() self.server.serve_in_thread()
self.send_to_client = partial(send_to, self.socket_path) self.client = Client(self.socket_path)
def tearDown(self): def tearDown(self):
self.server.close() self.server.close()
def test_sending_dict(self): def test_sending_dict(self):
data = {"foo": "bar"} data = {"foo": "bar"}
response = self.send_to_client(data) response = self.client.send(data)
self.assertEqual(response, data) self.assertEqual(response, data)
def test_sending_full_buffer(self): def test_sending_full_buffer(self):
data = {"foo" + str(i): i for i in range(BUFFER_SIZE)} data = {"foo" + str(i): i for i in range(BUFFER_SIZE)}
response = self.send_to_client(data) response = self.client.send(data)
self.assertEqual(response, data) self.assertEqual(response, data)
def test_multiple_send_to_same_server(self): def test_multiple_send_to_same_server(self):
data = {"foo": "bar"} data = {"foo": "bar"}
for _ in range(10): for _ in range(10):
response = self.send_to_client(data) response = self.client.send(data)
self.assertEqual(response, data) self.assertEqual(response, data)