add timeout support
This commit is contained in:
parent
0e662e04b6
commit
c5bcb9a050
2 changed files with 105 additions and 15 deletions
|
@ -30,6 +30,10 @@ import encrypt
|
|||
import eventloop
|
||||
from common import parse_header
|
||||
|
||||
|
||||
TIMEOUTS_CLEAN_SIZE = 512
|
||||
TIMEOUT_PRECISION = 4
|
||||
|
||||
CMD_CONNECT = 1
|
||||
CMD_BIND = 2
|
||||
CMD_UDP_ASSOCIATE = 3
|
||||
|
@ -66,7 +70,9 @@ BUF_SIZE = 8 * 1024
|
|||
|
||||
|
||||
class TCPRelayHandler(object):
|
||||
def __init__(self, fd_to_handlers, loop, local_sock, config, is_local):
|
||||
def __init__(self, server, fd_to_handlers, loop, local_sock, config,
|
||||
is_local):
|
||||
self._server = server
|
||||
self._fd_to_handlers = fd_to_handlers
|
||||
self._loop = loop
|
||||
self._local_sock = local_sock
|
||||
|
@ -80,10 +86,25 @@ class TCPRelayHandler(object):
|
|||
self._data_to_write_to_remote = []
|
||||
self._upstream_status = WAIT_STATUS_READING
|
||||
self._downstream_status = WAIT_STATUS_INIT
|
||||
self._remote_address = None
|
||||
fd_to_handlers[local_sock.fileno()] = self
|
||||
local_sock.setblocking(False)
|
||||
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR)
|
||||
self.last_activity = 0
|
||||
self.update_activity()
|
||||
|
||||
def __hash__(self):
|
||||
# default __hash__ is id / 16
|
||||
# we want to eliminate collisions
|
||||
return id(self)
|
||||
|
||||
@property
|
||||
def remote_address(self):
|
||||
return self._remote_address
|
||||
|
||||
def update_activity(self):
|
||||
self._server.update_activity(self)
|
||||
|
||||
def update_stream(self, stream, status):
|
||||
dirty = False
|
||||
|
@ -146,7 +167,7 @@ class TCPRelayHandler(object):
|
|||
logging.error('write_all_to_sock:unknown socket')
|
||||
|
||||
def on_local_read(self):
|
||||
# TODO update timeout
|
||||
self.update_activity()
|
||||
if not self._local_sock:
|
||||
return
|
||||
is_local = self._is_local
|
||||
|
@ -211,6 +232,7 @@ class TCPRelayHandler(object):
|
|||
addrtype, remote_addr, remote_port, header_length =\
|
||||
header_result
|
||||
logging.debug('connecting %s:%d' % (remote_addr, remote_port))
|
||||
self._remote_address = (remote_addr, remote_port)
|
||||
if is_local:
|
||||
# forward address to remote
|
||||
self.write_to_sock('\x05\x00\x00\x01' +
|
||||
|
@ -257,7 +279,7 @@ class TCPRelayHandler(object):
|
|||
self.destroy()
|
||||
|
||||
def on_remote_read(self):
|
||||
# TODO update timeout
|
||||
self.update_activity()
|
||||
data = None
|
||||
try:
|
||||
data = self._remote_sock.recv(BUF_SIZE)
|
||||
|
@ -325,6 +347,10 @@ class TCPRelayHandler(object):
|
|||
logging.warn('unknown socket')
|
||||
|
||||
def destroy(self):
|
||||
if self._remote_address:
|
||||
logging.debug('destroy: %s:%d' %
|
||||
self._remote_address)
|
||||
else:
|
||||
logging.debug('destroy')
|
||||
if self._remote_sock:
|
||||
self._loop.remove(self._remote_sock)
|
||||
|
@ -336,6 +362,7 @@ class TCPRelayHandler(object):
|
|||
del self._fd_to_handlers[self._local_sock.fileno()]
|
||||
self._local_sock.close()
|
||||
self._local_sock = None
|
||||
self._server.remove_handler(self)
|
||||
|
||||
|
||||
class TCPRelay(object):
|
||||
|
@ -347,6 +374,12 @@ class TCPRelay(object):
|
|||
self._fd_to_handlers = {}
|
||||
self._last_time = time.time()
|
||||
|
||||
self._timeout = config['timeout']
|
||||
self._timeouts = [] # a list for all the handlers
|
||||
self._timeout_offset = 0 # last checked position for timeout
|
||||
# we trim the timeouts once a while
|
||||
self._handler_to_timeouts = {} # key: handler value: index in timeouts
|
||||
|
||||
if is_local:
|
||||
listen_addr = config['local_address']
|
||||
listen_port = config['local_port']
|
||||
|
@ -376,19 +409,74 @@ class TCPRelay(object):
|
|||
self._eventloop.add(self._server_socket,
|
||||
eventloop.POLL_IN | eventloop.POLL_ERR)
|
||||
|
||||
def remove_handler(self, handler):
|
||||
index = self._handler_to_timeouts.get(hash(handler), -1)
|
||||
if index >= 0:
|
||||
# delete is O(n), so we just set it to None
|
||||
self._timeouts[index] = None
|
||||
del self._handler_to_timeouts[hash(handler)]
|
||||
|
||||
def update_activity(self, handler):
|
||||
""" set handler to active """
|
||||
now = int(time.time())
|
||||
if now - handler.last_activity < TIMEOUT_PRECISION:
|
||||
# thus we can lower timeout modification frequency
|
||||
return
|
||||
handler.last_activity = now
|
||||
index = self._handler_to_timeouts.get(hash(handler), -1)
|
||||
if index >= 0:
|
||||
# delete is O(n), so we just set it to None
|
||||
self._timeouts[index] = None
|
||||
length = len(self._timeouts)
|
||||
self._timeouts.append(handler)
|
||||
self._handler_to_timeouts[hash(handler)] = length
|
||||
|
||||
def _sweep_timeout(self):
|
||||
# tornado's timeout memory management is more flexible that we need
|
||||
# we just need a sorted last_activity queue and it's faster that heapq
|
||||
# in fact we can do O(1) insertion/remove so we invent our own
|
||||
if self._timeouts:
|
||||
now = time.time()
|
||||
length = len(self._timeouts)
|
||||
pos = self._timeout_offset
|
||||
while pos < length:
|
||||
handler = self._timeouts[pos]
|
||||
if handler:
|
||||
if now - handler.last_activity < self._timeout:
|
||||
break
|
||||
else:
|
||||
if handler.remote_address:
|
||||
logging.warn('timed out: %s:%d' %
|
||||
handler.remote_address)
|
||||
else:
|
||||
logging.warn('timed out')
|
||||
handler.destroy()
|
||||
self._timeouts[pos] = None # free memory
|
||||
pos += 1
|
||||
else:
|
||||
pos += 1
|
||||
if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1:
|
||||
# clean up the timeout queue when it gets larger than half
|
||||
# of the queue
|
||||
self._timeouts = self._timeouts[pos:]
|
||||
for key in self._handler_to_timeouts:
|
||||
self._handler_to_timeouts[key] -= pos
|
||||
pos = 0
|
||||
self._timeout_offset = pos
|
||||
|
||||
def _handle_events(self, events):
|
||||
for sock, fd, event in events:
|
||||
if sock:
|
||||
logging.debug('fd %d %s', fd,
|
||||
eventloop.EVENT_NAMES.get(event, event))
|
||||
# if sock:
|
||||
# logging.debug('fd %d %s', fd,
|
||||
# eventloop.EVENT_NAMES.get(event, event))
|
||||
if sock == self._server_socket:
|
||||
if event & eventloop.POLL_ERR:
|
||||
# TODO
|
||||
raise Exception('server_socket error')
|
||||
try:
|
||||
logging.debug('accept')
|
||||
# logging.debug('accept')
|
||||
conn = self._server_socket.accept()
|
||||
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
|
||||
TCPRelayHandler(self, self._fd_to_handlers, self._eventloop,
|
||||
conn[0], self._config, self._is_local)
|
||||
except (OSError, IOError) as e:
|
||||
error_no = eventloop.errno_from_exception(e)
|
||||
|
@ -405,8 +493,8 @@ class TCPRelay(object):
|
|||
logging.warn('poll removed fd')
|
||||
|
||||
now = time.time()
|
||||
if now - self._last_time > 5:
|
||||
# TODO sweep timeouts
|
||||
if now - self._last_time > TIMEOUT_PRECISION:
|
||||
self._sweep_timeout()
|
||||
self._last_time = now
|
||||
|
||||
def close(self):
|
||||
|
|
|
@ -64,10 +64,10 @@ def check_config(config):
|
|||
if (config.get('method', '') or '').lower() == 'rc4':
|
||||
logging.warn('warning: RC4 is not safe; please use a safer cipher, '
|
||||
'like AES-256-CFB')
|
||||
if (int(config.get('timeout', 300)) or 300) < 100:
|
||||
if config.get('timeout', 300) < 100:
|
||||
logging.warn('warning: your timeout %d seems too short' %
|
||||
int(config.get('timeout')))
|
||||
if (int(config.get('timeout', 300)) or 300) > 600:
|
||||
if config.get('timeout', 300) > 600:
|
||||
logging.warn('warning: your timeout %d seems too long' %
|
||||
int(config.get('timeout')))
|
||||
|
||||
|
@ -114,6 +114,8 @@ def get_config(is_local):
|
|||
config['local_address'] = value
|
||||
elif key == '-v':
|
||||
config['verbose'] = True
|
||||
elif key == '-t':
|
||||
config['timeout'] = int(value)
|
||||
elif key == '--fast-open':
|
||||
config['fast_open'] = True
|
||||
elif key == '--workers':
|
||||
|
|
Loading…
Reference in a new issue