1
Fork 0

Don't close over function

This commit is contained in:
Jake Howard 2019-01-01 19:58:47 +00:00
parent dfbaf808fe
commit a6a4069158
Signed by: jake
GPG key ID: 57AFB45680EDD477

View file

@ -1,6 +1,7 @@
import argparse import argparse
import asyncio import asyncio
from collections import namedtuple from collections import namedtuple
from functools import partial
Route = namedtuple("Route", ["listen_port", "destination_host", "destination_port"]) Route = namedtuple("Route", ["listen_port", "destination_host", "destination_port"])
@ -23,22 +24,25 @@ async def pipe(reader, writer):
writer.close() writer.close()
async def create_proxy_pipe(route: Route): async def handle_client(route, local_reader, local_writer):
async def handle_client(local_reader, local_writer): try:
try: remote_reader, remote_writer = await asyncio.open_connection(
remote_reader, remote_writer = await asyncio.open_connection( route.destination_host, route.destination_port
route.destination_host, route.destination_port )
) await asyncio.gather(
await asyncio.gather( pipe(local_reader, remote_writer), pipe(remote_reader, local_writer)
pipe(local_reader, remote_writer), pipe(remote_reader, local_writer) )
) except ConnectionRefusedError:
except ConnectionRefusedError: print("Connection to {} refused".format(destination_host_display(route)))
print("Connection to {} refused".format(destination_host_display(route))) pass
pass finally:
finally: local_writer.close()
local_writer.close()
server = await asyncio.start_server(handle_client, "0.0.0.0", route.listen_port)
async def create_proxy_pipe(route: Route):
server = await asyncio.start_server(
partial(handle_client, route), "0.0.0.0", route.listen_port
)
print( print(
"Routing from {} to {}".format( "Routing from {} to {}".format(
route.listen_port, destination_host_display(route) route.listen_port, destination_host_display(route)
@ -54,6 +58,7 @@ async def main():
) )
args = parser.parse_args() args = parser.parse_args()
servers = [create_proxy_pipe(route) for route in args.route] servers = [create_proxy_pipe(route) for route in args.route]
print("Starting servers...")
await asyncio.gather(*servers) await asyncio.gather(*servers)