add timeout support
This commit is contained in:
parent
0e662e04b6
commit
c5bcb9a050
2 changed files with 105 additions and 15 deletions
|
@ -30,6 +30,10 @@ import encrypt
|
||||||
import eventloop
|
import eventloop
|
||||||
from common import parse_header
|
from common import parse_header
|
||||||
|
|
||||||
|
|
||||||
|
TIMEOUTS_CLEAN_SIZE = 512
|
||||||
|
TIMEOUT_PRECISION = 4
|
||||||
|
|
||||||
CMD_CONNECT = 1
|
CMD_CONNECT = 1
|
||||||
CMD_BIND = 2
|
CMD_BIND = 2
|
||||||
CMD_UDP_ASSOCIATE = 3
|
CMD_UDP_ASSOCIATE = 3
|
||||||
|
@ -66,7 +70,9 @@ BUF_SIZE = 8 * 1024
|
||||||
|
|
||||||
|
|
||||||
class TCPRelayHandler(object):
|
class TCPRelayHandler(object):
|
||||||
def __init__(self, fd_to_handlers, loop, local_sock, config, is_local):
|
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
|
||||||
|
is_local):
|
||||||
|
self._server = server
|
||||||
self._fd_to_handlers = fd_to_handlers
|
self._fd_to_handlers = fd_to_handlers
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
self._local_sock = local_sock
|
self._local_sock = local_sock
|
||||||
|
@ -80,10 +86,25 @@ class TCPRelayHandler(object):
|
||||||
self._data_to_write_to_remote = []
|
self._data_to_write_to_remote = []
|
||||||
self._upstream_status = WAIT_STATUS_READING
|
self._upstream_status = WAIT_STATUS_READING
|
||||||
self._downstream_status = WAIT_STATUS_INIT
|
self._downstream_status = WAIT_STATUS_INIT
|
||||||
|
self._remote_address = None
|
||||||
fd_to_handlers[local_sock.fileno()] = self
|
fd_to_handlers[local_sock.fileno()] = self
|
||||||
local_sock.setblocking(False)
|
local_sock.setblocking(False)
|
||||||
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||||
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR)
|
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR)
|
||||||
|
self.last_activity = 0
|
||||||
|
self.update_activity()
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# default __hash__ is id / 16
|
||||||
|
# we want to eliminate collisions
|
||||||
|
return id(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remote_address(self):
|
||||||
|
return self._remote_address
|
||||||
|
|
||||||
|
def update_activity(self):
|
||||||
|
self._server.update_activity(self)
|
||||||
|
|
||||||
def update_stream(self, stream, status):
|
def update_stream(self, stream, status):
|
||||||
dirty = False
|
dirty = False
|
||||||
|
@ -146,7 +167,7 @@ class TCPRelayHandler(object):
|
||||||
logging.error('write_all_to_sock:unknown socket')
|
logging.error('write_all_to_sock:unknown socket')
|
||||||
|
|
||||||
def on_local_read(self):
|
def on_local_read(self):
|
||||||
# TODO update timeout
|
self.update_activity()
|
||||||
if not self._local_sock:
|
if not self._local_sock:
|
||||||
return
|
return
|
||||||
is_local = self._is_local
|
is_local = self._is_local
|
||||||
|
@ -211,6 +232,7 @@ class TCPRelayHandler(object):
|
||||||
addrtype, remote_addr, remote_port, header_length =\
|
addrtype, remote_addr, remote_port, header_length =\
|
||||||
header_result
|
header_result
|
||||||
logging.debug('connecting %s:%d' % (remote_addr, remote_port))
|
logging.debug('connecting %s:%d' % (remote_addr, remote_port))
|
||||||
|
self._remote_address = (remote_addr, remote_port)
|
||||||
if is_local:
|
if is_local:
|
||||||
# forward address to remote
|
# forward address to remote
|
||||||
self.write_to_sock('\x05\x00\x00\x01' +
|
self.write_to_sock('\x05\x00\x00\x01' +
|
||||||
|
@ -257,7 +279,7 @@ class TCPRelayHandler(object):
|
||||||
self.destroy()
|
self.destroy()
|
||||||
|
|
||||||
def on_remote_read(self):
|
def on_remote_read(self):
|
||||||
# TODO update timeout
|
self.update_activity()
|
||||||
data = None
|
data = None
|
||||||
try:
|
try:
|
||||||
data = self._remote_sock.recv(BUF_SIZE)
|
data = self._remote_sock.recv(BUF_SIZE)
|
||||||
|
@ -325,6 +347,10 @@ class TCPRelayHandler(object):
|
||||||
logging.warn('unknown socket')
|
logging.warn('unknown socket')
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
|
if self._remote_address:
|
||||||
|
logging.debug('destroy: %s:%d' %
|
||||||
|
self._remote_address)
|
||||||
|
else:
|
||||||
logging.debug('destroy')
|
logging.debug('destroy')
|
||||||
if self._remote_sock:
|
if self._remote_sock:
|
||||||
self._loop.remove(self._remote_sock)
|
self._loop.remove(self._remote_sock)
|
||||||
|
@ -336,6 +362,7 @@ class TCPRelayHandler(object):
|
||||||
del self._fd_to_handlers[self._local_sock.fileno()]
|
del self._fd_to_handlers[self._local_sock.fileno()]
|
||||||
self._local_sock.close()
|
self._local_sock.close()
|
||||||
self._local_sock = None
|
self._local_sock = None
|
||||||
|
self._server.remove_handler(self)
|
||||||
|
|
||||||
|
|
||||||
class TCPRelay(object):
|
class TCPRelay(object):
|
||||||
|
@ -347,6 +374,12 @@ class TCPRelay(object):
|
||||||
self._fd_to_handlers = {}
|
self._fd_to_handlers = {}
|
||||||
self._last_time = time.time()
|
self._last_time = time.time()
|
||||||
|
|
||||||
|
self._timeout = config['timeout']
|
||||||
|
self._timeouts = [] # a list for all the handlers
|
||||||
|
self._timeout_offset = 0 # last checked position for timeout
|
||||||
|
# we trim the timeouts once a while
|
||||||
|
self._handler_to_timeouts = {} # key: handler value: index in timeouts
|
||||||
|
|
||||||
if is_local:
|
if is_local:
|
||||||
listen_addr = config['local_address']
|
listen_addr = config['local_address']
|
||||||
listen_port = config['local_port']
|
listen_port = config['local_port']
|
||||||
|
@ -376,19 +409,74 @@ class TCPRelay(object):
|
||||||
self._eventloop.add(self._server_socket,
|
self._eventloop.add(self._server_socket,
|
||||||
eventloop.POLL_IN | eventloop.POLL_ERR)
|
eventloop.POLL_IN | eventloop.POLL_ERR)
|
||||||
|
|
||||||
|
def remove_handler(self, handler):
|
||||||
|
index = self._handler_to_timeouts.get(hash(handler), -1)
|
||||||
|
if index >= 0:
|
||||||
|
# delete is O(n), so we just set it to None
|
||||||
|
self._timeouts[index] = None
|
||||||
|
del self._handler_to_timeouts[hash(handler)]
|
||||||
|
|
||||||
|
def update_activity(self, handler):
|
||||||
|
""" set handler to active """
|
||||||
|
now = int(time.time())
|
||||||
|
if now - handler.last_activity < TIMEOUT_PRECISION:
|
||||||
|
# thus we can lower timeout modification frequency
|
||||||
|
return
|
||||||
|
handler.last_activity = now
|
||||||
|
index = self._handler_to_timeouts.get(hash(handler), -1)
|
||||||
|
if index >= 0:
|
||||||
|
# delete is O(n), so we just set it to None
|
||||||
|
self._timeouts[index] = None
|
||||||
|
length = len(self._timeouts)
|
||||||
|
self._timeouts.append(handler)
|
||||||
|
self._handler_to_timeouts[hash(handler)] = length
|
||||||
|
|
||||||
|
def _sweep_timeout(self):
|
||||||
|
# tornado's timeout memory management is more flexible that we need
|
||||||
|
# we just need a sorted last_activity queue and it's faster that heapq
|
||||||
|
# in fact we can do O(1) insertion/remove so we invent our own
|
||||||
|
if self._timeouts:
|
||||||
|
now = time.time()
|
||||||
|
length = len(self._timeouts)
|
||||||
|
pos = self._timeout_offset
|
||||||
|
while pos < length:
|
||||||
|
handler = self._timeouts[pos]
|
||||||
|
if handler:
|
||||||
|
if now - handler.last_activity < self._timeout:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if handler.remote_address:
|
||||||
|
logging.warn('timed out: %s:%d' %
|
||||||
|
handler.remote_address)
|
||||||
|
else:
|
||||||
|
logging.warn('timed out')
|
||||||
|
handler.destroy()
|
||||||
|
self._timeouts[pos] = None # free memory
|
||||||
|
pos += 1
|
||||||
|
else:
|
||||||
|
pos += 1
|
||||||
|
if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1:
|
||||||
|
# clean up the timeout queue when it gets larger than half
|
||||||
|
# of the queue
|
||||||
|
self._timeouts = self._timeouts[pos:]
|
||||||
|
for key in self._handler_to_timeouts:
|
||||||
|
self._handler_to_timeouts[key] -= pos
|
||||||
|
pos = 0
|
||||||
|
self._timeout_offset = pos
|
||||||
|
|
||||||
def _handle_events(self, events):
|
def _handle_events(self, events):
|
||||||
for sock, fd, event in events:
|
for sock, fd, event in events:
|
||||||
if sock:
|
# if sock:
|
||||||
logging.debug('fd %d %s', fd,
|
# logging.debug('fd %d %s', fd,
|
||||||
eventloop.EVENT_NAMES.get(event, event))
|
# eventloop.EVENT_NAMES.get(event, event))
|
||||||
if sock == self._server_socket:
|
if sock == self._server_socket:
|
||||||
if event & eventloop.POLL_ERR:
|
if event & eventloop.POLL_ERR:
|
||||||
# TODO
|
# TODO
|
||||||
raise Exception('server_socket error')
|
raise Exception('server_socket error')
|
||||||
try:
|
try:
|
||||||
logging.debug('accept')
|
# logging.debug('accept')
|
||||||
conn = self._server_socket.accept()
|
conn = self._server_socket.accept()
|
||||||
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
|
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
|
||||||
conn[0], self._config, self._is_local)
|
conn[0], self._config, self._is_local)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
error_no = eventloop.errno_from_exception(e)
|
error_no = eventloop.errno_from_exception(e)
|
||||||
|
@ -405,8 +493,8 @@ class TCPRelay(object):
|
||||||
logging.warn('poll removed fd')
|
logging.warn('poll removed fd')
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if now - self._last_time > 5:
|
if now - self._last_time > TIMEOUT_PRECISION:
|
||||||
# TODO sweep timeouts
|
self._sweep_timeout()
|
||||||
self._last_time = now
|
self._last_time = now
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
|
@ -64,10 +64,10 @@ def check_config(config):
|
||||||
if (config.get('method', '') or '').lower() == 'rc4':
|
if (config.get('method', '') or '').lower() == 'rc4':
|
||||||
logging.warn('warning: RC4 is not safe; please use a safer cipher, '
|
logging.warn('warning: RC4 is not safe; please use a safer cipher, '
|
||||||
'like AES-256-CFB')
|
'like AES-256-CFB')
|
||||||
if (int(config.get('timeout', 300)) or 300) < 100:
|
if config.get('timeout', 300) < 100:
|
||||||
logging.warn('warning: your timeout %d seems too short' %
|
logging.warn('warning: your timeout %d seems too short' %
|
||||||
int(config.get('timeout')))
|
int(config.get('timeout')))
|
||||||
if (int(config.get('timeout', 300)) or 300) > 600:
|
if config.get('timeout', 300) > 600:
|
||||||
logging.warn('warning: your timeout %d seems too long' %
|
logging.warn('warning: your timeout %d seems too long' %
|
||||||
int(config.get('timeout')))
|
int(config.get('timeout')))
|
||||||
|
|
||||||
|
@ -114,6 +114,8 @@ def get_config(is_local):
|
||||||
config['local_address'] = value
|
config['local_address'] = value
|
||||||
elif key == '-v':
|
elif key == '-v':
|
||||||
config['verbose'] = True
|
config['verbose'] = True
|
||||||
|
elif key == '-t':
|
||||||
|
config['timeout'] = int(value)
|
||||||
elif key == '--fast-open':
|
elif key == '--fast-open':
|
||||||
config['fast_open'] = True
|
config['fast_open'] = True
|
||||||
elif key == '--workers':
|
elif key == '--workers':
|
||||||
|
|
Loading…
Reference in a new issue