new UDP over TCP protocol, merge master
This commit is contained in:
parent
469d9f7bfa
commit
a9ea55c396
5 changed files with 1198 additions and 214 deletions
|
@ -18,7 +18,6 @@
|
||||||
from __future__ import absolute_import, division, print_function, \
|
from __future__ import absolute_import, division, print_function, \
|
||||||
with_statement
|
with_statement
|
||||||
|
|
||||||
import time
|
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
@ -256,7 +255,6 @@ class DNSResolver(object):
|
||||||
self._hostname_to_cb = {}
|
self._hostname_to_cb = {}
|
||||||
self._cb_to_hostname = {}
|
self._cb_to_hostname = {}
|
||||||
self._cache = lru_cache.LRUCache(timeout=300)
|
self._cache = lru_cache.LRUCache(timeout=300)
|
||||||
self._last_time = time.time()
|
|
||||||
self._sock = None
|
self._sock = None
|
||||||
self._servers = None
|
self._servers = None
|
||||||
self._parse_resolv()
|
self._parse_resolv()
|
||||||
|
@ -304,7 +302,7 @@ class DNSResolver(object):
|
||||||
except IOError:
|
except IOError:
|
||||||
self._hosts['localhost'] = '127.0.0.1'
|
self._hosts['localhost'] = '127.0.0.1'
|
||||||
|
|
||||||
def add_to_loop(self, loop, ref=False):
|
def add_to_loop(self, loop):
|
||||||
if self._loop:
|
if self._loop:
|
||||||
raise Exception('already add to loop')
|
raise Exception('already add to loop')
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
|
@ -312,8 +310,8 @@ class DNSResolver(object):
|
||||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
||||||
socket.SOL_UDP)
|
socket.SOL_UDP)
|
||||||
self._sock.setblocking(False)
|
self._sock.setblocking(False)
|
||||||
loop.add(self._sock, eventloop.POLL_IN)
|
loop.add(self._sock, eventloop.POLL_IN, self)
|
||||||
loop.add_handler(self.handle_events, ref=ref)
|
loop.add_periodic(self.handle_periodic)
|
||||||
|
|
||||||
def _call_callback(self, hostname, ip, error=None):
|
def _call_callback(self, hostname, ip, error=None):
|
||||||
callbacks = self._hostname_to_cb.get(hostname, [])
|
callbacks = self._hostname_to_cb.get(hostname, [])
|
||||||
|
@ -354,10 +352,9 @@ class DNSResolver(object):
|
||||||
self._call_callback(hostname, None)
|
self._call_callback(hostname, None)
|
||||||
break
|
break
|
||||||
|
|
||||||
def handle_events(self, events):
|
def handle_event(self, sock, fd, event):
|
||||||
for sock, fd, event in events:
|
|
||||||
if sock != self._sock:
|
if sock != self._sock:
|
||||||
continue
|
return
|
||||||
if event & eventloop.POLL_ERR:
|
if event & eventloop.POLL_ERR:
|
||||||
logging.error('dns socket err')
|
logging.error('dns socket err')
|
||||||
self._loop.remove(self._sock)
|
self._loop.remove(self._sock)
|
||||||
|
@ -366,18 +363,16 @@ class DNSResolver(object):
|
||||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
||||||
socket.SOL_UDP)
|
socket.SOL_UDP)
|
||||||
self._sock.setblocking(False)
|
self._sock.setblocking(False)
|
||||||
self._loop.add(self._sock, eventloop.POLL_IN)
|
self._loop.add(self._sock, eventloop.POLL_IN, self)
|
||||||
else:
|
else:
|
||||||
data, addr = sock.recvfrom(1024)
|
data, addr = sock.recvfrom(1024)
|
||||||
if addr[0] not in self._servers:
|
if addr[0] not in self._servers:
|
||||||
logging.warn('received a packet other than our dns')
|
logging.warn('received a packet other than our dns')
|
||||||
break
|
return
|
||||||
self._handle_data(data)
|
self._handle_data(data)
|
||||||
break
|
|
||||||
now = time.time()
|
def handle_periodic(self):
|
||||||
if now - self._last_time > CACHE_SWEEP_INTERVAL:
|
|
||||||
self._cache.sweep()
|
self._cache.sweep()
|
||||||
self._last_time = now
|
|
||||||
|
|
||||||
def remove_callback(self, callback):
|
def remove_callback(self, callback):
|
||||||
hostname = self._cb_to_hostname.get(callback)
|
hostname = self._cb_to_hostname.get(callback)
|
||||||
|
@ -430,6 +425,9 @@ class DNSResolver(object):
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self._sock:
|
if self._sock:
|
||||||
|
if self._loop:
|
||||||
|
self._loop.remove_periodic(self.handle_periodic)
|
||||||
|
self._loop.remove(self._sock)
|
||||||
self._sock.close()
|
self._sock.close()
|
||||||
self._sock = None
|
self._sock = None
|
||||||
|
|
||||||
|
@ -437,7 +435,7 @@ class DNSResolver(object):
|
||||||
def test():
|
def test():
|
||||||
dns_resolver = DNSResolver()
|
dns_resolver = DNSResolver()
|
||||||
loop = eventloop.EventLoop()
|
loop = eventloop.EventLoop()
|
||||||
dns_resolver.add_to_loop(loop, ref=True)
|
dns_resolver.add_to_loop(loop)
|
||||||
|
|
||||||
global counter
|
global counter
|
||||||
counter = 0
|
counter = 0
|
||||||
|
@ -451,8 +449,8 @@ def test():
|
||||||
print(result, error)
|
print(result, error)
|
||||||
counter += 1
|
counter += 1
|
||||||
if counter == 9:
|
if counter == 9:
|
||||||
loop.remove_handler(dns_resolver.handle_events)
|
|
||||||
dns_resolver.close()
|
dns_resolver.close()
|
||||||
|
loop.stop()
|
||||||
a_callback = callback
|
a_callback = callback
|
||||||
return a_callback
|
return a_callback
|
||||||
|
|
||||||
|
|
|
@ -151,6 +151,15 @@ def pre_parse_header(data):
|
||||||
data = data[rand_data_size + 2:]
|
data = data[rand_data_size + 2:]
|
||||||
elif datatype == 0x81:
|
elif datatype == 0x81:
|
||||||
data = data[1:]
|
data = data[1:]
|
||||||
|
elif datatype == 0x82 :
|
||||||
|
if len(data) <= 3:
|
||||||
|
return None
|
||||||
|
rand_data_size = struct.unpack('>H', data[1:3])[0]
|
||||||
|
if rand_data_size + 3 >= len(data):
|
||||||
|
logging.warn('header too short, maybe wrong password or '
|
||||||
|
'encryption method')
|
||||||
|
return None
|
||||||
|
data = data[rand_data_size + 3:]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def parse_header(data):
|
def parse_header(data):
|
||||||
|
@ -158,8 +167,8 @@ def parse_header(data):
|
||||||
dest_addr = None
|
dest_addr = None
|
||||||
dest_port = None
|
dest_port = None
|
||||||
header_length = 0
|
header_length = 0
|
||||||
connecttype = (addrtype & 8) and 1 or 0
|
connecttype = (addrtype & 0x10) and 1 or 0
|
||||||
addrtype &= ~8
|
addrtype &= ~0x10
|
||||||
if addrtype == ADDRTYPE_IPV4:
|
if addrtype == ADDRTYPE_IPV4:
|
||||||
if len(data) >= 7:
|
if len(data) >= 7:
|
||||||
dest_addr = socket.inet_ntoa(data[1:5])
|
dest_addr = socket.inet_ntoa(data[1:5])
|
||||||
|
|
|
@ -22,6 +22,7 @@ from __future__ import absolute_import, division, print_function, \
|
||||||
with_statement
|
with_statement
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import socket
|
import socket
|
||||||
import select
|
import select
|
||||||
import errno
|
import errno
|
||||||
|
@ -51,23 +52,8 @@ EVENT_NAMES = {
|
||||||
POLL_NVAL: 'POLL_NVAL',
|
POLL_NVAL: 'POLL_NVAL',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# we check timeouts every TIMEOUT_PRECISION seconds
|
||||||
class EpollLoop(object):
|
TIMEOUT_PRECISION = 10
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._epoll = select.epoll()
|
|
||||||
|
|
||||||
def poll(self, timeout):
|
|
||||||
return self._epoll.poll(timeout)
|
|
||||||
|
|
||||||
def add_fd(self, fd, mode):
|
|
||||||
self._epoll.register(fd, mode)
|
|
||||||
|
|
||||||
def remove_fd(self, fd):
|
|
||||||
self._epoll.unregister(fd)
|
|
||||||
|
|
||||||
def modify_fd(self, fd, mode):
|
|
||||||
self._epoll.modify(fd, mode)
|
|
||||||
|
|
||||||
|
|
||||||
class KqueueLoop(object):
|
class KqueueLoop(object):
|
||||||
|
@ -100,17 +86,17 @@ class KqueueLoop(object):
|
||||||
results[fd] |= POLL_OUT
|
results[fd] |= POLL_OUT
|
||||||
return results.items()
|
return results.items()
|
||||||
|
|
||||||
def add_fd(self, fd, mode):
|
def register(self, fd, mode):
|
||||||
self._fds[fd] = mode
|
self._fds[fd] = mode
|
||||||
self._control(fd, mode, select.KQ_EV_ADD)
|
self._control(fd, mode, select.KQ_EV_ADD)
|
||||||
|
|
||||||
def remove_fd(self, fd):
|
def unregister(self, fd):
|
||||||
self._control(fd, self._fds[fd], select.KQ_EV_DELETE)
|
self._control(fd, self._fds[fd], select.KQ_EV_DELETE)
|
||||||
del self._fds[fd]
|
del self._fds[fd]
|
||||||
|
|
||||||
def modify_fd(self, fd, mode):
|
def modify(self, fd, mode):
|
||||||
self.remove_fd(fd)
|
self.unregister(fd)
|
||||||
self.add_fd(fd, mode)
|
self.register(fd, mode)
|
||||||
|
|
||||||
|
|
||||||
class SelectLoop(object):
|
class SelectLoop(object):
|
||||||
|
@ -129,7 +115,7 @@ class SelectLoop(object):
|
||||||
results[fd] |= p[1]
|
results[fd] |= p[1]
|
||||||
return results.items()
|
return results.items()
|
||||||
|
|
||||||
def add_fd(self, fd, mode):
|
def register(self, fd, mode):
|
||||||
if mode & POLL_IN:
|
if mode & POLL_IN:
|
||||||
self._r_list.add(fd)
|
self._r_list.add(fd)
|
||||||
if mode & POLL_OUT:
|
if mode & POLL_OUT:
|
||||||
|
@ -137,7 +123,7 @@ class SelectLoop(object):
|
||||||
if mode & POLL_ERR:
|
if mode & POLL_ERR:
|
||||||
self._x_list.add(fd)
|
self._x_list.add(fd)
|
||||||
|
|
||||||
def remove_fd(self, fd):
|
def unregister(self, fd):
|
||||||
if fd in self._r_list:
|
if fd in self._r_list:
|
||||||
self._r_list.remove(fd)
|
self._r_list.remove(fd)
|
||||||
if fd in self._w_list:
|
if fd in self._w_list:
|
||||||
|
@ -145,16 +131,15 @@ class SelectLoop(object):
|
||||||
if fd in self._x_list:
|
if fd in self._x_list:
|
||||||
self._x_list.remove(fd)
|
self._x_list.remove(fd)
|
||||||
|
|
||||||
def modify_fd(self, fd, mode):
|
def modify(self, fd, mode):
|
||||||
self.remove_fd(fd)
|
self.unregister(fd)
|
||||||
self.add_fd(fd, mode)
|
self.register(fd, mode)
|
||||||
|
|
||||||
|
|
||||||
class EventLoop(object):
|
class EventLoop(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._iterating = False
|
|
||||||
if hasattr(select, 'epoll'):
|
if hasattr(select, 'epoll'):
|
||||||
self._impl = EpollLoop()
|
self._impl = select.epoll()
|
||||||
model = 'epoll'
|
model = 'epoll'
|
||||||
elif hasattr(select, 'kqueue'):
|
elif hasattr(select, 'kqueue'):
|
||||||
self._impl = KqueueLoop()
|
self._impl = KqueueLoop()
|
||||||
|
@ -165,72 +150,71 @@ class EventLoop(object):
|
||||||
else:
|
else:
|
||||||
raise Exception('can not find any available functions in select '
|
raise Exception('can not find any available functions in select '
|
||||||
'package')
|
'package')
|
||||||
self._fd_to_f = {}
|
self._fdmap = {} # (f, handler)
|
||||||
self._handlers = []
|
self._last_time = time.time()
|
||||||
self._ref_handlers = []
|
self._periodic_callbacks = []
|
||||||
self._handlers_to_remove = []
|
self._stopping = False
|
||||||
logging.debug('using event model: %s', model)
|
logging.debug('using event model: %s', model)
|
||||||
|
|
||||||
def poll(self, timeout=None):
|
def poll(self, timeout=None):
|
||||||
events = self._impl.poll(timeout)
|
events = self._impl.poll(timeout)
|
||||||
return [(self._fd_to_f[fd], fd, event) for fd, event in events]
|
return [(self._fdmap[fd][0], fd, event) for fd, event in events]
|
||||||
|
|
||||||
def add(self, f, mode):
|
def add(self, f, mode, handler):
|
||||||
fd = f.fileno()
|
fd = f.fileno()
|
||||||
self._fd_to_f[fd] = f
|
self._fdmap[fd] = (f, handler)
|
||||||
self._impl.add_fd(fd, mode)
|
self._impl.register(fd, mode)
|
||||||
|
|
||||||
def remove(self, f):
|
def remove(self, f):
|
||||||
fd = f.fileno()
|
fd = f.fileno()
|
||||||
del self._fd_to_f[fd]
|
del self._fdmap[fd]
|
||||||
self._impl.remove_fd(fd)
|
self._impl.unregister(fd)
|
||||||
|
|
||||||
|
def add_periodic(self, callback):
|
||||||
|
self._periodic_callbacks.append(callback)
|
||||||
|
|
||||||
|
def remove_periodic(self, callback):
|
||||||
|
self._periodic_callbacks.remove(callback)
|
||||||
|
|
||||||
def modify(self, f, mode):
|
def modify(self, f, mode):
|
||||||
fd = f.fileno()
|
fd = f.fileno()
|
||||||
self._impl.modify_fd(fd, mode)
|
self._impl.modify(fd, mode)
|
||||||
|
|
||||||
def add_handler(self, handler, ref=True):
|
def stop(self):
|
||||||
self._handlers.append(handler)
|
self._stopping = True
|
||||||
if ref:
|
|
||||||
# when all ref handlers are removed, loop stops
|
|
||||||
self._ref_handlers.append(handler)
|
|
||||||
|
|
||||||
def remove_handler(self, handler):
|
|
||||||
if handler in self._ref_handlers:
|
|
||||||
self._ref_handlers.remove(handler)
|
|
||||||
if self._iterating:
|
|
||||||
self._handlers_to_remove.append(handler)
|
|
||||||
else:
|
|
||||||
self._handlers.remove(handler)
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
events = []
|
events = []
|
||||||
while self._ref_handlers:
|
while not self._stopping:
|
||||||
|
asap = False
|
||||||
try:
|
try:
|
||||||
events = self.poll(1)
|
events = self.poll(TIMEOUT_PRECISION)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
|
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
|
||||||
# EPIPE: Happens when the client closes the connection
|
# EPIPE: Happens when the client closes the connection
|
||||||
# EINTR: Happens when received a signal
|
# EINTR: Happens when received a signal
|
||||||
# handles them as soon as possible
|
# handles them as soon as possible
|
||||||
|
asap = True
|
||||||
logging.debug('poll:%s', e)
|
logging.debug('poll:%s', e)
|
||||||
else:
|
else:
|
||||||
logging.error('poll:%s', e)
|
logging.error('poll:%s', e)
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
self._iterating = True
|
|
||||||
for handler in self._handlers:
|
for sock, fd, event in events:
|
||||||
# TODO when there are a lot of handlers
|
handler = self._fdmap.get(fd, None)
|
||||||
|
if handler is not None:
|
||||||
|
handler = handler[1]
|
||||||
try:
|
try:
|
||||||
handler(events)
|
handler.handle_event(sock, fd, event)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
shell.print_exception(e)
|
shell.print_exception(e)
|
||||||
if self._handlers_to_remove:
|
now = time.time()
|
||||||
for handler in self._handlers_to_remove:
|
if asap or now - self._last_time >= TIMEOUT_PRECISION:
|
||||||
self._handlers.remove(handler)
|
for callback in self._periodic_callbacks:
|
||||||
self._handlers_to_remove = []
|
callback()
|
||||||
self._iterating = False
|
self._last_time = now
|
||||||
|
|
||||||
|
|
||||||
# from tornado
|
# from tornado
|
||||||
|
|
|
@ -115,6 +115,7 @@ class TCPRelayHandler(object):
|
||||||
self._fastopen_connected = False
|
self._fastopen_connected = False
|
||||||
self._data_to_write_to_local = []
|
self._data_to_write_to_local = []
|
||||||
self._data_to_write_to_remote = []
|
self._data_to_write_to_remote = []
|
||||||
|
self._udp_data_send_buffer = ''
|
||||||
self._upstream_status = WAIT_STATUS_READING
|
self._upstream_status = WAIT_STATUS_READING
|
||||||
self._downstream_status = WAIT_STATUS_INIT
|
self._downstream_status = WAIT_STATUS_INIT
|
||||||
self._client_address = local_sock.getpeername()[:2]
|
self._client_address = local_sock.getpeername()[:2]
|
||||||
|
@ -128,7 +129,8 @@ class TCPRelayHandler(object):
|
||||||
fd_to_handlers[local_sock.fileno()] = self
|
fd_to_handlers[local_sock.fileno()] = self
|
||||||
local_sock.setblocking(False)
|
local_sock.setblocking(False)
|
||||||
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||||
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR)
|
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR,
|
||||||
|
self._server)
|
||||||
self.last_activity = 0
|
self.last_activity = 0
|
||||||
self._update_activity()
|
self._update_activity()
|
||||||
|
|
||||||
|
@ -185,6 +187,8 @@ class TCPRelayHandler(object):
|
||||||
if self._upstream_status & WAIT_STATUS_WRITING:
|
if self._upstream_status & WAIT_STATUS_WRITING:
|
||||||
event |= eventloop.POLL_OUT
|
event |= eventloop.POLL_OUT
|
||||||
self._loop.modify(self._remote_sock, event)
|
self._loop.modify(self._remote_sock, event)
|
||||||
|
if self._remote_sock_v6:
|
||||||
|
self._loop.modify(self._remote_sock_v6, event)
|
||||||
|
|
||||||
def _write_to_sock(self, data, sock):
|
def _write_to_sock(self, data, sock):
|
||||||
# write data to sock
|
# write data to sock
|
||||||
|
@ -193,20 +197,33 @@ class TCPRelayHandler(object):
|
||||||
if not data or not sock:
|
if not data or not sock:
|
||||||
return False
|
return False
|
||||||
#logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp))
|
#logging.debug("_write_to_sock %s %s %s" % (self._remote_sock, sock, self._remote_udp))
|
||||||
if self._remote_udp and self._remote_sock == sock:
|
uncomplete = False
|
||||||
|
if self._remote_udp and sock == self._remote_sock:
|
||||||
try:
|
try:
|
||||||
|
self._udp_data_send_buffer += data
|
||||||
|
#logging.info('UDP over TCP sendto %d %s' % (len(data), binascii.hexlify(data)))
|
||||||
|
while len(self._udp_data_send_buffer) > 6:
|
||||||
|
length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0]
|
||||||
|
|
||||||
|
if length > len(self._udp_data_send_buffer):
|
||||||
|
break
|
||||||
|
|
||||||
|
data = self._udp_data_send_buffer[:length]
|
||||||
|
self._udp_data_send_buffer = self._udp_data_send_buffer[length:]
|
||||||
|
|
||||||
frag = common.ord(data[2])
|
frag = common.ord(data[2])
|
||||||
if frag != 0:
|
if frag != 0:
|
||||||
logging.warn('drop a message since frag is %d' % (frag,))
|
logging.warn('drop a message since frag is %d' % (frag,))
|
||||||
return False
|
continue
|
||||||
else:
|
else:
|
||||||
data = data[3:]
|
data = data[3:]
|
||||||
header_result = parse_header(data)
|
header_result = parse_header(data)
|
||||||
if header_result is None:
|
if header_result is None:
|
||||||
return False
|
continue
|
||||||
connecttype, dest_addr, dest_port, header_length = header_result
|
connecttype, dest_addr, dest_port, header_length = header_result
|
||||||
addrs = socket.getaddrinfo(dest_addr, dest_port, 0,
|
addrs = socket.getaddrinfo(dest_addr, dest_port, 0,
|
||||||
socket.SOCK_DGRAM, socket.SOL_UDP)
|
socket.SOCK_DGRAM, socket.SOL_UDP)
|
||||||
|
#logging.info('UDP over TCP sendto %s:%d %d bytes from %s:%d' % (dest_addr, dest_port, len(data), self._client_address[0], self._client_address[1]))
|
||||||
if addrs:
|
if addrs:
|
||||||
af, socktype, proto, canonname, server_addr = addrs[0]
|
af, socktype, proto, canonname, server_addr = addrs[0]
|
||||||
data = data[header_length:]
|
data = data[header_length:]
|
||||||
|
@ -218,10 +235,16 @@ class TCPRelayHandler(object):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
#trace = traceback.format_exc()
|
#trace = traceback.format_exc()
|
||||||
#logging.error(trace)
|
#logging.error(trace)
|
||||||
logging.error(e)
|
error_no = eventloop.errno_from_exception(e)
|
||||||
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
|
||||||
|
errno.EWOULDBLOCK):
|
||||||
|
uncomplete = True
|
||||||
|
else:
|
||||||
|
shell.print_exception(e)
|
||||||
|
self.destroy()
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
uncomplete = False
|
|
||||||
try:
|
try:
|
||||||
l = len(data)
|
l = len(data)
|
||||||
s = sock.send(data)
|
s = sock.send(data)
|
||||||
|
@ -270,7 +293,7 @@ class TCPRelayHandler(object):
|
||||||
remote_sock = \
|
remote_sock = \
|
||||||
self._create_remote_socket(self._chosen_server[0],
|
self._create_remote_socket(self._chosen_server[0],
|
||||||
self._chosen_server[1])
|
self._chosen_server[1])
|
||||||
self._loop.add(remote_sock, eventloop.POLL_ERR)
|
self._loop.add(remote_sock, eventloop.POLL_ERR, self._server)
|
||||||
data = b''.join(self._data_to_write_to_remote)
|
data = b''.join(self._data_to_write_to_remote)
|
||||||
l = len(data)
|
l = len(data)
|
||||||
s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server)
|
s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server)
|
||||||
|
@ -382,6 +405,11 @@ class TCPRelayHandler(object):
|
||||||
remote_sock_v6 = socket.socket(af, socktype, proto)
|
remote_sock_v6 = socket.socket(af, socktype, proto)
|
||||||
self._remote_sock_v6 = remote_sock_v6
|
self._remote_sock_v6 = remote_sock_v6
|
||||||
self._fd_to_handlers[remote_sock_v6.fileno()] = self
|
self._fd_to_handlers[remote_sock_v6.fileno()] = self
|
||||||
|
remote_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
|
||||||
|
remote_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
|
||||||
|
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 32)
|
||||||
|
remote_sock_v6.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 32)
|
||||||
|
|
||||||
|
|
||||||
remote_sock.setblocking(False)
|
remote_sock.setblocking(False)
|
||||||
if self._remote_udp:
|
if self._remote_udp:
|
||||||
|
@ -421,10 +449,12 @@ class TCPRelayHandler(object):
|
||||||
remote_port)
|
remote_port)
|
||||||
if self._remote_udp:
|
if self._remote_udp:
|
||||||
self._loop.add(remote_sock,
|
self._loop.add(remote_sock,
|
||||||
eventloop.POLL_IN)
|
eventloop.POLL_IN,
|
||||||
|
self._server)
|
||||||
if self._remote_sock_v6:
|
if self._remote_sock_v6:
|
||||||
self._loop.add(self._remote_sock_v6,
|
self._loop.add(self._remote_sock_v6,
|
||||||
eventloop.POLL_IN)
|
eventloop.POLL_IN,
|
||||||
|
self._server)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
remote_sock.connect((remote_addr, remote_port))
|
remote_sock.connect((remote_addr, remote_port))
|
||||||
|
@ -433,10 +463,16 @@ class TCPRelayHandler(object):
|
||||||
errno.EINPROGRESS:
|
errno.EINPROGRESS:
|
||||||
pass
|
pass
|
||||||
self._loop.add(remote_sock,
|
self._loop.add(remote_sock,
|
||||||
eventloop.POLL_ERR | eventloop.POLL_OUT)
|
eventloop.POLL_ERR | eventloop.POLL_OUT,
|
||||||
|
self._server)
|
||||||
self._stage = STAGE_CONNECTING
|
self._stage = STAGE_CONNECTING
|
||||||
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
|
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
|
||||||
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
|
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
|
||||||
|
if self._remote_udp:
|
||||||
|
while self._data_to_write_to_remote:
|
||||||
|
data = self._data_to_write_to_remote[0]
|
||||||
|
del self._data_to_write_to_remote[0]
|
||||||
|
self._write_to_sock(data, self._remote_sock)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
shell.print_exception(e)
|
shell.print_exception(e)
|
||||||
|
@ -495,11 +531,12 @@ class TCPRelayHandler(object):
|
||||||
port = struct.pack('>H', addr[1])
|
port = struct.pack('>H', addr[1])
|
||||||
try:
|
try:
|
||||||
ip = socket.inet_aton(addr[0])
|
ip = socket.inet_aton(addr[0])
|
||||||
data = '\x00\x00\x00\x01' + ip + port + data
|
data = '\x00\x01' + ip + port + data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ip = socket.inet_pton(socket.AF_INET6, addr[0])
|
ip = socket.inet_pton(socket.AF_INET6, addr[0])
|
||||||
data = '\x00\x00\x00\x04' + ip + port + data
|
data = '\x00\x04' + ip + port + data
|
||||||
logging.info('UDP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1]))
|
data = struct.pack('>H', len(data) + 2) + data
|
||||||
|
#logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1]))
|
||||||
else:
|
else:
|
||||||
data = self._remote_sock.recv(BUF_SIZE)
|
data = self._remote_sock.recv(BUF_SIZE)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
|
@ -637,7 +674,6 @@ class TCPRelay(object):
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._eventloop = None
|
self._eventloop = None
|
||||||
self._fd_to_handlers = {}
|
self._fd_to_handlers = {}
|
||||||
self._last_time = time.time()
|
|
||||||
self.server_transfer_ul = 0L
|
self.server_transfer_ul = 0L
|
||||||
self.server_transfer_dl = 0L
|
self.server_transfer_dl = 0L
|
||||||
|
|
||||||
|
@ -680,10 +716,9 @@ class TCPRelay(object):
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise Exception('already closed')
|
raise Exception('already closed')
|
||||||
self._eventloop = loop
|
self._eventloop = loop
|
||||||
loop.add_handler(self._handle_events)
|
|
||||||
|
|
||||||
self._eventloop.add(self._server_socket,
|
self._eventloop.add(self._server_socket,
|
||||||
eventloop.POLL_IN | eventloop.POLL_ERR)
|
eventloop.POLL_IN | eventloop.POLL_ERR, self)
|
||||||
|
self._eventloop.add_periodic(self.handle_periodic)
|
||||||
|
|
||||||
def remove_handler(self, handler):
|
def remove_handler(self, handler):
|
||||||
index = self._handler_to_timeouts.get(hash(handler), -1)
|
index = self._handler_to_timeouts.get(hash(handler), -1)
|
||||||
|
@ -695,7 +730,7 @@ class TCPRelay(object):
|
||||||
def update_activity(self, handler):
|
def update_activity(self, handler):
|
||||||
# set handler to active
|
# set handler to active
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
if now - handler.last_activity < TIMEOUT_PRECISION:
|
if now - handler.last_activity < eventloop.TIMEOUT_PRECISION:
|
||||||
# thus we can lower timeout modification frequency
|
# thus we can lower timeout modification frequency
|
||||||
return
|
return
|
||||||
handler.last_activity = now
|
handler.last_activity = now
|
||||||
|
@ -741,9 +776,8 @@ class TCPRelay(object):
|
||||||
pos = 0
|
pos = 0
|
||||||
self._timeout_offset = pos
|
self._timeout_offset = pos
|
||||||
|
|
||||||
def _handle_events(self, events):
|
def handle_event(self, sock, fd, event):
|
||||||
# handle events and dispatch to handlers
|
# handle events and dispatch to handlers
|
||||||
for sock, fd, event in events:
|
|
||||||
if sock:
|
if sock:
|
||||||
logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
|
logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd,
|
||||||
eventloop.EVENT_NAMES.get(event, event))
|
eventloop.EVENT_NAMES.get(event, event))
|
||||||
|
@ -761,7 +795,7 @@ class TCPRelay(object):
|
||||||
error_no = eventloop.errno_from_exception(e)
|
error_no = eventloop.errno_from_exception(e)
|
||||||
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
|
||||||
errno.EWOULDBLOCK):
|
errno.EWOULDBLOCK):
|
||||||
continue
|
return
|
||||||
else:
|
else:
|
||||||
shell.print_exception(e)
|
shell.print_exception(e)
|
||||||
if self._config['verbose']:
|
if self._config['verbose']:
|
||||||
|
@ -774,20 +808,23 @@ class TCPRelay(object):
|
||||||
else:
|
else:
|
||||||
logging.warn('poll removed fd')
|
logging.warn('poll removed fd')
|
||||||
|
|
||||||
now = time.time()
|
def handle_periodic(self):
|
||||||
if now - self._last_time > TIMEOUT_PRECISION:
|
|
||||||
self._sweep_timeout()
|
|
||||||
self._last_time = now
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
if self._server_socket:
|
if self._server_socket:
|
||||||
self._eventloop.remove(self._server_socket)
|
self._eventloop.remove(self._server_socket)
|
||||||
self._server_socket.close()
|
self._server_socket.close()
|
||||||
self._server_socket = None
|
self._server_socket = None
|
||||||
logging.info('closed listen port %d', self._listen_port)
|
logging.info('closed TCP port %d', self._listen_port)
|
||||||
if not self._fd_to_handlers:
|
if not self._fd_to_handlers:
|
||||||
self._eventloop.remove_handler(self._handle_events)
|
logging.info('stopping')
|
||||||
|
self._eventloop.stop()
|
||||||
|
self._sweep_timeout()
|
||||||
|
|
||||||
def close(self, next_tick=False):
|
def close(self, next_tick=False):
|
||||||
|
logging.debug('TCP close')
|
||||||
self._closed = True
|
self._closed = True
|
||||||
if not next_tick:
|
if not next_tick:
|
||||||
|
if self._eventloop:
|
||||||
|
self._eventloop.remove_periodic(self.handle_periodic)
|
||||||
|
self._eventloop.remove(self._server_socket)
|
||||||
self._server_socket.close()
|
self._server_socket.close()
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue