diff --git a/tcp_nat_proxy.py b/tcp_nat_proxy.py index cd25ada..aed68ca 100644 --- a/tcp_nat_proxy.py +++ b/tcp_nat_proxy.py @@ -1,6 +1,7 @@ import argparse import asyncio from collections import namedtuple +from functools import partial Route = namedtuple("Route", ["listen_port", "destination_host", "destination_port"]) @@ -23,22 +24,25 @@ async def pipe(reader, writer): writer.close() -async def create_proxy_pipe(route: Route): - async def handle_client(local_reader, local_writer): - try: - remote_reader, remote_writer = await asyncio.open_connection( - route.destination_host, route.destination_port - ) - await asyncio.gather( - pipe(local_reader, remote_writer), pipe(remote_reader, local_writer) - ) - except ConnectionRefusedError: - print("Connection to {} refused".format(destination_host_display(route))) - pass - finally: - local_writer.close() +async def handle_client(route, local_reader, local_writer): + try: + remote_reader, remote_writer = await asyncio.open_connection( + route.destination_host, route.destination_port + ) + await asyncio.gather( + pipe(local_reader, remote_writer), pipe(remote_reader, local_writer) + ) + except ConnectionRefusedError: + print("Connection to {} refused".format(destination_host_display(route))) + pass + finally: + local_writer.close() - server = await asyncio.start_server(handle_client, "0.0.0.0", route.listen_port) + +async def create_proxy_pipe(route: Route): + server = await asyncio.start_server( + partial(handle_client, route), "0.0.0.0", route.listen_port + ) print( "Routing from {} to {}".format( route.listen_port, destination_host_display(route) @@ -54,6 +58,7 @@ async def main(): ) args = parser.parse_args() servers = [create_proxy_pipe(route) for route in args.route] + print("Starting servers...") await asyncio.gather(*servers)