Don't close over function
This commit is contained in:
parent
dfbaf808fe
commit
a6a4069158
1 changed files with 20 additions and 15 deletions
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
Route = namedtuple("Route", ["listen_port", "destination_host", "destination_port"])
|
||||
|
||||
|
@ -23,22 +24,25 @@ async def pipe(reader, writer):
|
|||
writer.close()
|
||||
|
||||
|
||||
async def create_proxy_pipe(route: Route):
|
||||
async def handle_client(local_reader, local_writer):
|
||||
try:
|
||||
remote_reader, remote_writer = await asyncio.open_connection(
|
||||
route.destination_host, route.destination_port
|
||||
)
|
||||
await asyncio.gather(
|
||||
pipe(local_reader, remote_writer), pipe(remote_reader, local_writer)
|
||||
)
|
||||
except ConnectionRefusedError:
|
||||
print("Connection to {} refused".format(destination_host_display(route)))
|
||||
pass
|
||||
finally:
|
||||
local_writer.close()
|
||||
async def handle_client(route, local_reader, local_writer):
|
||||
try:
|
||||
remote_reader, remote_writer = await asyncio.open_connection(
|
||||
route.destination_host, route.destination_port
|
||||
)
|
||||
await asyncio.gather(
|
||||
pipe(local_reader, remote_writer), pipe(remote_reader, local_writer)
|
||||
)
|
||||
except ConnectionRefusedError:
|
||||
print("Connection to {} refused".format(destination_host_display(route)))
|
||||
pass
|
||||
finally:
|
||||
local_writer.close()
|
||||
|
||||
server = await asyncio.start_server(handle_client, "0.0.0.0", route.listen_port)
|
||||
|
||||
async def create_proxy_pipe(route: Route):
|
||||
server = await asyncio.start_server(
|
||||
partial(handle_client, route), "0.0.0.0", route.listen_port
|
||||
)
|
||||
print(
|
||||
"Routing from {} to {}".format(
|
||||
route.listen_port, destination_host_display(route)
|
||||
|
@ -54,6 +58,7 @@ async def main():
|
|||
)
|
||||
args = parser.parse_args()
|
||||
servers = [create_proxy_pipe(route) for route in args.route]
|
||||
print("Starting servers...")
|
||||
await asyncio.gather(*servers)
|
||||
|
||||
|
||||
|
|
Reference in a new issue