Don't close over function
This commit is contained in:
parent
dfbaf808fe
commit
a6a4069158
1 changed files with 20 additions and 15 deletions
|
@ -1,6 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
Route = namedtuple("Route", ["listen_port", "destination_host", "destination_port"])
|
Route = namedtuple("Route", ["listen_port", "destination_host", "destination_port"])
|
||||||
|
|
||||||
|
@ -23,22 +24,25 @@ async def pipe(reader, writer):
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_proxy_pipe(route: Route):
|
async def handle_client(route, local_reader, local_writer):
|
||||||
async def handle_client(local_reader, local_writer):
|
try:
|
||||||
try:
|
remote_reader, remote_writer = await asyncio.open_connection(
|
||||||
remote_reader, remote_writer = await asyncio.open_connection(
|
route.destination_host, route.destination_port
|
||||||
route.destination_host, route.destination_port
|
)
|
||||||
)
|
await asyncio.gather(
|
||||||
await asyncio.gather(
|
pipe(local_reader, remote_writer), pipe(remote_reader, local_writer)
|
||||||
pipe(local_reader, remote_writer), pipe(remote_reader, local_writer)
|
)
|
||||||
)
|
except ConnectionRefusedError:
|
||||||
except ConnectionRefusedError:
|
print("Connection to {} refused".format(destination_host_display(route)))
|
||||||
print("Connection to {} refused".format(destination_host_display(route)))
|
pass
|
||||||
pass
|
finally:
|
||||||
finally:
|
local_writer.close()
|
||||||
local_writer.close()
|
|
||||||
|
|
||||||
server = await asyncio.start_server(handle_client, "0.0.0.0", route.listen_port)
|
|
||||||
|
async def create_proxy_pipe(route: Route):
|
||||||
|
server = await asyncio.start_server(
|
||||||
|
partial(handle_client, route), "0.0.0.0", route.listen_port
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
"Routing from {} to {}".format(
|
"Routing from {} to {}".format(
|
||||||
route.listen_port, destination_host_display(route)
|
route.listen_port, destination_host_display(route)
|
||||||
|
@ -54,6 +58,7 @@ async def main():
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
servers = [create_proxy_pipe(route) for route in args.route]
|
servers = [create_proxy_pipe(route) for route in args.route]
|
||||||
|
print("Starting servers...")
|
||||||
await asyncio.gather(*servers)
|
await asyncio.gather(*servers)
|
||||||
|
|
||||||
|
|
||||||
|
|
Reference in a new issue