add async dns to tcp relay
This commit is contained in:
parent
5c274a1bc7
commit
bcdc1e9671
2 changed files with 77 additions and 38 deletions
|
@ -173,8 +173,8 @@ def parse_response(data):
|
||||||
res_tc = header[1] & 2
|
res_tc = header[1] & 2
|
||||||
# res_ra = header[2] & 128
|
# res_ra = header[2] & 128
|
||||||
res_rcode = header[2] & 15
|
res_rcode = header[2] & 15
|
||||||
assert res_tc == 0
|
# assert res_tc == 0
|
||||||
assert res_rcode in [0, 3]
|
# assert res_rcode in [0, 3]
|
||||||
res_qdcount = header[3]
|
res_qdcount = header[3]
|
||||||
res_ancount = header[4]
|
res_ancount = header[4]
|
||||||
res_nscount = header[5]
|
res_nscount = header[5]
|
||||||
|
@ -308,7 +308,11 @@ class DNSResolver(object):
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
if self._cb_to_hostname.__contains__(callback):
|
if self._cb_to_hostname.__contains__(callback):
|
||||||
del self._cb_to_hostname[callback]
|
del self._cb_to_hostname[callback]
|
||||||
|
if ip:
|
||||||
callback((hostname, ip), None)
|
callback((hostname, ip), None)
|
||||||
|
else:
|
||||||
|
callback((hostname, None),
|
||||||
|
Exception('unknown hostname %s' % hostname))
|
||||||
if self._hostname_to_cb.__contains__(hostname):
|
if self._hostname_to_cb.__contains__(hostname):
|
||||||
del self._hostname_to_cb[hostname]
|
del self._hostname_to_cb[hostname]
|
||||||
if self._hostname_status.__contains__(hostname):
|
if self._hostname_status.__contains__(hostname):
|
||||||
|
@ -329,6 +333,7 @@ class DNSResolver(object):
|
||||||
self._hostname_status[hostname] = STATUS_IPV6
|
self._hostname_status[hostname] = STATUS_IPV6
|
||||||
self._send_req(hostname, QTYPE_AAAA)
|
self._send_req(hostname, QTYPE_AAAA)
|
||||||
else:
|
else:
|
||||||
|
if ip:
|
||||||
self._cache[hostname] = ip
|
self._cache[hostname] = ip
|
||||||
self._call_callback(hostname, ip)
|
self._call_callback(hostname, ip)
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,7 @@ CMD_UDP_ASSOCIATE = 3
|
||||||
STAGE_INIT = 0
|
STAGE_INIT = 0
|
||||||
STAGE_HELLO = 1
|
STAGE_HELLO = 1
|
||||||
STAGE_UDP_ASSOC = 2
|
STAGE_UDP_ASSOC = 2
|
||||||
|
STAGE_DNS = 3
|
||||||
STAGE_REPLY = 4
|
STAGE_REPLY = 4
|
||||||
STAGE_STREAM = 5
|
STAGE_STREAM = 5
|
||||||
STAGE_DESTROYED = -1
|
STAGE_DESTROYED = -1
|
||||||
|
@ -75,13 +76,14 @@ BUF_SIZE = 8 * 1024
|
||||||
|
|
||||||
class TCPRelayHandler(object):
|
class TCPRelayHandler(object):
|
||||||
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
|
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
|
||||||
is_local):
|
dns_resolver, is_local):
|
||||||
self._server = server
|
self._server = server
|
||||||
self._fd_to_handlers = fd_to_handlers
|
self._fd_to_handlers = fd_to_handlers
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
self._local_sock = local_sock
|
self._local_sock = local_sock
|
||||||
self._remote_sock = None
|
self._remote_sock = None
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self._dns_resolver = dns_resolver
|
||||||
self._is_local = is_local
|
self._is_local = is_local
|
||||||
self._stage = STAGE_INIT
|
self._stage = STAGE_INIT
|
||||||
self._encryptor = encrypt.Encryptor(config['password'],
|
self._encryptor = encrypt.Encryptor(config['password'],
|
||||||
|
@ -239,51 +241,81 @@ class TCPRelayHandler(object):
|
||||||
addrtype, remote_addr, remote_port, header_length = header_result
|
addrtype, remote_addr, remote_port, header_length = header_result
|
||||||
logging.info('connecting %s:%d' % (remote_addr, remote_port))
|
logging.info('connecting %s:%d' % (remote_addr, remote_port))
|
||||||
self._remote_address = (remote_addr, remote_port)
|
self._remote_address = (remote_addr, remote_port)
|
||||||
|
# pause reading
|
||||||
|
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
|
||||||
|
self._stage = STAGE_DNS
|
||||||
if self._is_local:
|
if self._is_local:
|
||||||
# forward address to remote
|
# forward address to remote
|
||||||
self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10',
|
self._write_to_sock('\x05\x00\x00\x01\x00\x00\x00\x00\x10\x10',
|
||||||
self._local_sock)
|
self._local_sock)
|
||||||
data_to_send = self._encryptor.encrypt(data)
|
data_to_send = self._encryptor.encrypt(data)
|
||||||
self._data_to_write_to_remote.append(data_to_send)
|
self._data_to_write_to_remote.append(data_to_send)
|
||||||
remote_addr = self._config['server']
|
# notice here may go into _handle_dns_resolved directly
|
||||||
remote_port = self._config['server_port']
|
self._dns_resolver.resolve(self._config['server'],
|
||||||
|
self._handle_dns_resolved)
|
||||||
else:
|
else:
|
||||||
if len(data) > header_length:
|
if len(data) > header_length:
|
||||||
self._data_to_write_to_remote.append(data[header_length:])
|
self._data_to_write_to_remote.append(data[header_length:])
|
||||||
|
# notice here may go into _handle_dns_resolved directly
|
||||||
|
self._dns_resolver.resolve(remote_addr,
|
||||||
|
self._handle_dns_resolved)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(e)
|
||||||
|
traceback.print_exc()
|
||||||
|
# TODO use logging when debug completed
|
||||||
|
self.destroy()
|
||||||
|
|
||||||
# TODO async DNS
|
def _handle_dns_resolved(self, result, error):
|
||||||
addrs = socket.getaddrinfo(remote_addr, remote_port, 0,
|
if error:
|
||||||
socket.SOCK_STREAM, socket.SOL_TCP)
|
logging.error(error)
|
||||||
|
self.destroy()
|
||||||
|
return
|
||||||
|
if result:
|
||||||
|
ip = result[1]
|
||||||
|
if ip:
|
||||||
|
try:
|
||||||
|
self._stage = STAGE_REPLY
|
||||||
|
remote_addr = self._remote_address[0]
|
||||||
|
remote_port = self._remote_address[1]
|
||||||
|
if self._is_local:
|
||||||
|
remote_addr = self._config['server']
|
||||||
|
remote_port = self._config['server_port']
|
||||||
|
addrs = socket.getaddrinfo(ip, remote_port, 0,
|
||||||
|
socket.SOCK_STREAM,
|
||||||
|
socket.SOL_TCP)
|
||||||
if len(addrs) == 0:
|
if len(addrs) == 0:
|
||||||
raise Exception("can't get addrinfo for %s:%d" %
|
raise Exception("getaddrinfo failed for %s:%d" %
|
||||||
(remote_addr, remote_port))
|
(remote_addr, remote_port))
|
||||||
af, socktype, proto, canonname, sa = addrs[0]
|
af, socktype, proto, canonname, sa = addrs[0]
|
||||||
remote_sock = socket.socket(af, socktype, proto)
|
remote_sock = socket.socket(af, socktype, proto)
|
||||||
self._remote_sock = remote_sock
|
self._remote_sock = remote_sock
|
||||||
self._fd_to_handlers[remote_sock.fileno()] = self
|
self._fd_to_handlers[remote_sock.fileno()] = self
|
||||||
remote_sock.setblocking(False)
|
remote_sock.setblocking(False)
|
||||||
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY,
|
||||||
|
1)
|
||||||
|
|
||||||
if self._is_local and self._config['fast_open']:
|
if self._is_local and self._config['fast_open']:
|
||||||
# wait for more data to arrive and send them in one SYN
|
# wait for more data to arrive and send them in one SYN
|
||||||
self._stage = STAGE_REPLY
|
self._stage = STAGE_REPLY
|
||||||
self._loop.add(remote_sock, eventloop.POLL_ERR)
|
self._loop.add(remote_sock, eventloop.POLL_ERR)
|
||||||
|
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
|
||||||
# TODO when there is already data in this packet
|
# TODO when there is already data in this packet
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
remote_sock.connect(sa)
|
remote_sock.connect(sa)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
if eventloop.errno_from_exception(e) == errno.EINPROGRESS:
|
if eventloop.errno_from_exception(e) == \
|
||||||
|
errno.EINPROGRESS:
|
||||||
pass
|
pass
|
||||||
self._loop.add(remote_sock,
|
self._loop.add(remote_sock,
|
||||||
eventloop.POLL_ERR | eventloop.POLL_OUT)
|
eventloop.POLL_ERR | eventloop.POLL_OUT)
|
||||||
self._stage = STAGE_REPLY
|
self._stage = STAGE_REPLY
|
||||||
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
|
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
|
||||||
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
|
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
|
||||||
except Exception as e:
|
return
|
||||||
|
except (OSError, IOError) as e:
|
||||||
logging.error(e)
|
logging.error(e)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# TODO use logging when debug completed
|
|
||||||
self.destroy()
|
self.destroy()
|
||||||
|
|
||||||
def _on_local_read(self):
|
def _on_local_read(self):
|
||||||
|
@ -422,6 +454,7 @@ class TCPRelayHandler(object):
|
||||||
del self._fd_to_handlers[self._local_sock.fileno()]
|
del self._fd_to_handlers[self._local_sock.fileno()]
|
||||||
self._local_sock.close()
|
self._local_sock.close()
|
||||||
self._local_sock = None
|
self._local_sock = None
|
||||||
|
self._dns_resolver.remove_callback(self._handle_dns_resolved)
|
||||||
self._server.remove_handler(self)
|
self._server.remove_handler(self)
|
||||||
|
|
||||||
|
|
||||||
|
@ -545,7 +578,8 @@ class TCPRelay(object):
|
||||||
# logging.debug('accept')
|
# logging.debug('accept')
|
||||||
conn = self._server_socket.accept()
|
conn = self._server_socket.accept()
|
||||||
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
|
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
|
||||||
conn[0], self._config, self._is_local)
|
conn[0], self._config, self._dns_resolver,
|
||||||
|
self._is_local)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
error_no = eventloop.errno_from_exception(e)
|
error_no = eventloop.errno_from_exception(e)
|
||||||
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
|
||||||
|
|
Loading…
Reference in a new issue