add config "udp_timeout"

drop TCP over UDP query
This commit is contained in:
BreakWa11 2016-05-17 17:27:33 +08:00
parent 4e43a4932d
commit d47b0d47ea
2 changed files with 70 additions and 65 deletions

View file

@ -236,6 +236,7 @@ def get_config(is_local):
config['obfs_param'] = to_str(config.get('obfs_param', '')) config['obfs_param'] = to_str(config.get('obfs_param', ''))
config['port_password'] = config.get('port_password', None) config['port_password'] = config.get('port_password', None)
config['timeout'] = int(config.get('timeout', 300)) config['timeout'] = int(config.get('timeout', 300))
config['udp_timeout'] = int(config.get('udp_timeout', config['timeout']))
config['fast_open'] = config.get('fast_open', False) config['fast_open'] = config.get('fast_open', False)
config['workers'] = config.get('workers', 1) config['workers'] = config.get('workers', 1)
config['pid-file'] = config.get('pid-file', '/var/run/shadowsocks.pid') config['pid-file'] = config.get('pid-file', '/var/run/shadowsocks.pid')

View file

@ -880,7 +880,7 @@ class UDPRelay(object):
self._method = config['method'] self._method = config['method']
self._timeout = config['timeout'] self._timeout = config['timeout']
self._is_local = is_local self._is_local = is_local
self._cache = lru_cache.LRUCache(timeout=config['timeout'], self._cache = lru_cache.LRUCache(timeout=config['udp_timeout'],
close_callback=self._close_client) close_callback=self._close_client)
self._client_fd_to_server_addr = {} self._client_fd_to_server_addr = {}
self._dns_cache = lru_cache.LRUCache(timeout=300) self._dns_cache = lru_cache.LRUCache(timeout=300)
@ -1023,6 +1023,86 @@ class UDPRelay(object):
return return
if type(data) is tuple: if type(data) is tuple:
return
return self._handle_tcp_over_udp(data, r_addr)
try:
header_result = parse_header(data)
except:
self._handel_protocol_error(r_addr, ogn_data)
return
if header_result is None:
self._handel_protocol_error(r_addr, ogn_data)
return
connecttype, dest_addr, dest_port, header_length = header_result
if self._is_local:
server_addr, server_port = self._get_a_server()
else:
server_addr, server_port = dest_addr, dest_port
addrs = self._dns_cache.get(server_addr, None)
if addrs is None:
# TODO async getaddrinfo
addrs = socket.getaddrinfo(server_addr, server_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP)
if not addrs:
# drop
return
else:
self._dns_cache[server_addr] = addrs
af, socktype, proto, canonname, sa = addrs[0]
key = client_key(r_addr, af)
client = self._cache.get(key, None)
if not client:
if self._forbidden_iplist:
if common.to_str(sa[0]) in self._forbidden_iplist:
logging.debug('IP %s is in forbidden list, drop' %
common.to_str(sa[0]))
# drop
return
client = socket.socket(af, socktype, proto)
client.setblocking(False)
self._cache[key] = client
self._client_fd_to_server_addr[client.fileno()] = r_addr
self._sockets.add(client.fileno())
self._eventloop.add(client, eventloop.POLL_IN, self)
logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets)))
logging.info('UDP data to %s:%d from %s:%d' %
(common.to_str(server_addr), server_port,
r_addr[0], r_addr[1]))
self._cache.clear(256)
if self._is_local:
ref_iv = [encrypt.encrypt_new_iv(self._method)]
self._protocol.obfs.server_info.iv = ref_iv[0]
data = self._protocol.client_udp_pre_encrypt(data)
logging.info("%s" % (binascii.hexlify(data),))
data = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 1, data, ref_iv)
if not data:
return
else:
data = data[header_length:]
if not data:
return
try:
#logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data)))
client.sendto(data, (server_addr, server_port))
self.server_transfer_ul += len(data)
except IOError as e:
err = eventloop.errno_from_exception(e)
if err in (errno.EINPROGRESS, errno.EAGAIN):
pass
else:
shell.print_exception(e)
def _handle_tcp_over_udp(self, data, r_addr):
#(cmd, request_id, data) #(cmd, request_id, data)
#logging.info("UDP data %d %d %s" % (data[0], data[1], binascii.hexlify(data[2]))) #logging.info("UDP data %d %d %s" % (data[0], data[1], binascii.hexlify(data[2])))
try: try:
@ -1087,82 +1167,6 @@ class UDPRelay(object):
logging.error(trace) logging.error(trace)
return return
try:
header_result = parse_header(data)
except:
self._handel_protocol_error(r_addr, ogn_data)
return
if header_result is None:
self._handel_protocol_error(r_addr, ogn_data)
return
connecttype, dest_addr, dest_port, header_length = header_result
if self._is_local:
server_addr, server_port = self._get_a_server()
else:
server_addr, server_port = dest_addr, dest_port
addrs = self._dns_cache.get(server_addr, None)
if addrs is None:
addrs = socket.getaddrinfo(server_addr, server_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP)
if not addrs:
# drop
return
else:
self._dns_cache[server_addr] = addrs
af, socktype, proto, canonname, sa = addrs[0]
key = client_key(r_addr, af)
client = self._cache.get(key, None)
if not client:
# TODO async getaddrinfo
if self._forbidden_iplist:
if common.to_str(sa[0]) in self._forbidden_iplist:
logging.debug('IP %s is in forbidden list, drop' %
common.to_str(sa[0]))
# drop
return
client = socket.socket(af, socktype, proto)
client.setblocking(False)
self._cache[key] = client
self._client_fd_to_server_addr[client.fileno()] = r_addr
self._sockets.add(client.fileno())
self._eventloop.add(client, eventloop.POLL_IN, self)
logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets)))
logging.info('UDP data to %s:%d from %s:%d' %
(common.to_str(server_addr), server_port,
r_addr[0], r_addr[1]))
self._cache.clear(256)
if self._is_local:
ref_iv = [encrypt.encrypt_new_iv(self._method)]
self._protocol.obfs.server_info.iv = ref_iv[0]
data = self._protocol.client_udp_pre_encrypt(data)
logging.info("%s" % (binascii.hexlify(data),))
data = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 1, data, ref_iv)
if not data:
return
else:
data = data[header_length:]
if not data:
return
try:
#logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data)))
client.sendto(data, (server_addr, server_port))
self.server_transfer_ul += len(data)
except IOError as e:
err = eventloop.errno_from_exception(e)
if err in (errno.EINPROGRESS, errno.EAGAIN):
pass
else:
shell.print_exception(e)
def _handle_client(self, sock): def _handle_client(self, sock):
data, r_addr = sock.recvfrom(BUF_SIZE) data, r_addr = sock.recvfrom(BUF_SIZE)
if not data: if not data: