* make tcprelay.py less nested

* import traceback module at the top

* make loops in DNSResolver less nested

* make manager.py less nested

* introduce exception_handle decorator

make try/except block more clean

* apply exception_handle decorator to tcprelay

* quote condition judgement

* pep8 fix
This commit is contained in:
ahxxm 2016-10-07 12:30:17 +08:00 committed by mengskysama
parent 5c11527e1b
commit 5cd9f04948
6 changed files with 248 additions and 203 deletions

View File

@ -276,15 +276,18 @@ class DNSResolver(object):
content = f.readlines() content = f.readlines()
for line in content: for line in content:
line = line.strip() line = line.strip()
if line: if not (line and line.startswith(b'nameserver')):
if line.startswith(b'nameserver'): continue
parts = line.split()
if len(parts) >= 2: parts = line.split()
server = parts[1] if len(parts) < 2:
if common.is_ip(server) == socket.AF_INET: continue
if type(server) != str:
server = server.decode('utf8') server = parts[1]
self._servers.append(server) if common.is_ip(server) == socket.AF_INET:
if type(server) != str:
server = server.decode('utf8')
self._servers.append(server)
except IOError: except IOError:
pass pass
if not self._servers: if not self._servers:
@ -299,13 +302,17 @@ class DNSResolver(object):
for line in f.readlines(): for line in f.readlines():
line = line.strip() line = line.strip()
parts = line.split() parts = line.split()
if len(parts) >= 2: if len(parts) < 2:
ip = parts[0] continue
if common.is_ip(ip):
for i in range(1, len(parts)): ip = parts[0]
hostname = parts[i] if not common.is_ip(ip):
if hostname: continue
self._hosts[hostname] = ip
for i in range(1, len(parts)):
hostname = parts[i]
if hostname:
self._hosts[hostname] = ip
except IOError: except IOError:
self._hosts['localhost'] = '127.0.0.1' self._hosts['localhost'] = '127.0.0.1'

View File

@ -25,6 +25,7 @@ import os
import time import time
import socket import socket
import select import select
import traceback
import errno import errno
import logging import logging
from collections import defaultdict from collections import defaultdict
@ -204,7 +205,6 @@ class EventLoop(object):
logging.debug('poll:%s', e) logging.debug('poll:%s', e)
else: else:
logging.error('poll:%s', e) logging.error('poll:%s', e)
import traceback
traceback.print_exc() traceback.print_exc()
continue continue

View File

@ -27,6 +27,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns
@shell.exception_handle(self_=False, exit_code=1)
def main(): def main():
shell.check_python() shell.check_python()
@ -37,36 +38,31 @@ def main():
os.chdir(p) os.chdir(p)
config = shell.get_config(True) config = shell.get_config(True)
daemon.daemon_exec(config) daemon.daemon_exec(config)
try: logging.info("starting local at %s:%d" %
logging.info("starting local at %s:%d" % (config['local_address'], config['local_port']))
(config['local_address'], config['local_port']))
dns_resolver = asyncdns.DNSResolver() dns_resolver = asyncdns.DNSResolver()
tcp_server = tcprelay.TCPRelay(config, dns_resolver, True) tcp_server = tcprelay.TCPRelay(config, dns_resolver, True)
udp_server = udprelay.UDPRelay(config, dns_resolver, True) udp_server = udprelay.UDPRelay(config, dns_resolver, True)
loop = eventloop.EventLoop() loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop) dns_resolver.add_to_loop(loop)
tcp_server.add_to_loop(loop) tcp_server.add_to_loop(loop)
udp_server.add_to_loop(loop) udp_server.add_to_loop(loop)
def handler(signum, _): def handler(signum, _):
logging.warn('received SIGQUIT, doing graceful shutting down..') logging.warn('received SIGQUIT, doing graceful shutting down..')
tcp_server.close(next_tick=True) tcp_server.close(next_tick=True)
udp_server.close(next_tick=True) udp_server.close(next_tick=True)
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler) signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler)
def int_handler(signum, _): def int_handler(signum, _):
sys.exit(1)
signal.signal(signal.SIGINT, int_handler)
daemon.set_user(config.get('user', None))
loop.run()
except Exception as e:
shell.print_exception(e)
sys.exit(1) sys.exit(1)
signal.signal(signal.SIGINT, int_handler)
daemon.set_user(config.get('user', None))
loop.run()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -173,18 +173,20 @@ class Manager(object):
self._statistics.clear() self._statistics.clear()
def _send_control_data(self, data): def _send_control_data(self, data):
if self._control_client_addr: if not self._control_client_addr:
try: return
self._control_socket.sendto(data, self._control_client_addr)
except (socket.error, OSError, IOError) as e: try:
error_no = eventloop.errno_from_exception(e) self._control_socket.sendto(data, self._control_client_addr)
if error_no in (errno.EAGAIN, errno.EINPROGRESS, except (socket.error, OSError, IOError) as e:
errno.EWOULDBLOCK): error_no = eventloop.errno_from_exception(e)
return if error_no in (errno.EAGAIN, errno.EINPROGRESS,
else: errno.EWOULDBLOCK):
shell.print_exception(e) return
if self._config['verbose']: else:
traceback.print_exc() shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
def run(self): def run(self):
self._loop.run() self._loop.run()

View File

@ -23,6 +23,10 @@ import json
import sys import sys
import getopt import getopt
import logging import logging
import traceback
from functools import wraps
from shadowsocks.common import to_bytes, to_str, IPNetwork from shadowsocks.common import to_bytes, to_str, IPNetwork
from shadowsocks import encrypt from shadowsocks import encrypt
@ -53,6 +57,49 @@ def print_exception(e):
traceback.print_exc() traceback.print_exc()
def exception_handle(self_, err_msg=None, exit_code=None,
destroy=False, conn_err=False):
# self_: if function passes self as first arg
def process_exception(e, self=None):
print_exception(e)
if err_msg:
logging.error(err_msg)
if exit_code:
sys.exit(1)
if not self_:
return
if conn_err:
addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(e, addr, port))
if self._config['verbose']:
traceback.print_exc()
if destroy:
self.destroy()
def decorator(func):
if self_:
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
func(self, *args, **kwargs)
except Exception as e:
process_exception(e, self)
else:
@wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
process_exception(e)
return wrapper
return decorator
def print_shadowsocks(): def print_shadowsocks():
version = '' version = ''
try: try:

View File

@ -190,21 +190,23 @@ class TCPRelayHandler(object):
if self._upstream_status != status: if self._upstream_status != status:
self._upstream_status = status self._upstream_status = status
dirty = True dirty = True
if dirty: if not dirty:
if self._local_sock: return
event = eventloop.POLL_ERR
if self._downstream_status & WAIT_STATUS_WRITING: if self._local_sock:
event |= eventloop.POLL_OUT event = eventloop.POLL_ERR
if self._upstream_status & WAIT_STATUS_READING: if self._downstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_IN event |= eventloop.POLL_OUT
self._loop.modify(self._local_sock, event) if self._upstream_status & WAIT_STATUS_READING:
if self._remote_sock: event |= eventloop.POLL_IN
event = eventloop.POLL_ERR self._loop.modify(self._local_sock, event)
if self._downstream_status & WAIT_STATUS_READING: if self._remote_sock:
event |= eventloop.POLL_IN event = eventloop.POLL_ERR
if self._upstream_status & WAIT_STATUS_WRITING: if self._downstream_status & WAIT_STATUS_READING:
event |= eventloop.POLL_OUT event |= eventloop.POLL_IN
self._loop.modify(self._remote_sock, event) if self._upstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_OUT
self._loop.modify(self._remote_sock, event)
def _write_to_sock(self, data, sock): def _write_to_sock(self, data, sock):
# write data to sock # write data to sock
@ -247,19 +249,20 @@ class TCPRelayHandler(object):
return True return True
def _handle_stage_connecting(self, data): def _handle_stage_connecting(self, data):
if self._is_local: if not self._is_local:
if self._ota_enable_session:
data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data)
else:
if self._ota_enable_session: if self._ota_enable_session:
self._ota_chunk_data(data, self._ota_chunk_data(data,
self._data_to_write_to_remote.append) self._data_to_write_to_remote.append)
else: else:
self._data_to_write_to_remote.append(data) self._data_to_write_to_remote.append(data)
if self._is_local and not self._fastopen_connected and \ return
self._config['fast_open']:
if self._ota_enable_session:
data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data)
if self._config['fast_open'] and not self._fastopen_connected:
# for sslocal and fastopen, we basically wait for data and use # for sslocal and fastopen, we basically wait for data and use
# sendto to connect # sendto to connect
try: try:
@ -293,93 +296,88 @@ class TCPRelayHandler(object):
traceback.print_exc() traceback.print_exc()
self.destroy() self.destroy()
@shell.exception_handle(self_=True, destroy=True, conn_err=True)
def _handle_stage_addr(self, data): def _handle_stage_addr(self, data):
try: if self._is_local:
if self._is_local: cmd = common.ord(data[1])
cmd = common.ord(data[1]) if cmd == CMD_UDP_ASSOCIATE:
if cmd == CMD_UDP_ASSOCIATE: logging.debug('UDP associate')
logging.debug('UDP associate') if self._local_sock.family == socket.AF_INET6:
if self._local_sock.family == socket.AF_INET6: header = b'\x05\x00\x00\x04'
header = b'\x05\x00\x00\x04'
else:
header = b'\x05\x00\x00\x01'
addr, port = self._local_sock.getsockname()[:2]
addr_to_send = socket.inet_pton(self._local_sock.family,
addr)
port_to_send = struct.pack('>H', port)
self._write_to_sock(header + addr_to_send + port_to_send,
self._local_sock)
self._stage = STAGE_UDP_ASSOC
# just wait for the client to disconnect
return
elif cmd == CMD_CONNECT:
# just trim VER CMD RSV
data = data[3:]
else: else:
logging.error('unknown command %d', cmd) header = b'\x05\x00\x00\x01'
addr, port = self._local_sock.getsockname()[:2]
addr_to_send = socket.inet_pton(self._local_sock.family,
addr)
port_to_send = struct.pack('>H', port)
self._write_to_sock(header + addr_to_send + port_to_send,
self._local_sock)
self._stage = STAGE_UDP_ASSOC
# just wait for the client to disconnect
return
elif cmd == CMD_CONNECT:
# just trim VER CMD RSV
data = data[3:]
else:
logging.error('unknown command %d', cmd)
self.destroy()
return
header_result = parse_header(data)
if header_result is None:
raise Exception('can not parse header')
addrtype, remote_addr, remote_port, header_length = header_result
logging.info('connecting %s:%d from %s:%d' %
(common.to_str(remote_addr), remote_port,
self._client_address[0], self._client_address[1]))
if self._is_local is False:
# spec https://shadowsocks.org/en/spec/one-time-auth.html
self._ota_enable_session = addrtype & ADDRTYPE_AUTH
if self._ota_enable and not self._ota_enable_session:
logging.warn('client one time auth is required')
return
if self._ota_enable_session:
if len(data) < header_length + ONETIMEAUTH_BYTES:
logging.warn('one time auth header is too short')
return None
offset = header_length + ONETIMEAUTH_BYTES
_hash = data[header_length: offset]
_data = data[:header_length]
key = self._encryptor.decipher_iv + self._encryptor.key
if onetimeauth_verify(_hash, _data, key) is False:
logging.warn('one time auth fail')
self.destroy() self.destroy()
return return
header_result = parse_header(data) header_length += ONETIMEAUTH_BYTES
if header_result is None: self._remote_address = (common.to_str(remote_addr), remote_port)
raise Exception('can not parse header') # pause reading
addrtype, remote_addr, remote_port, header_length = header_result self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
logging.info('connecting %s:%d from %s:%d' % self._stage = STAGE_DNS
(common.to_str(remote_addr), remote_port, if self._is_local:
self._client_address[0], self._client_address[1])) # forward address to remote
if self._is_local is False: self._write_to_sock((b'\x05\x00\x00\x01'
# spec https://shadowsocks.org/en/spec/one-time-auth.html b'\x00\x00\x00\x00\x10\x10'),
self._ota_enable_session = addrtype & ADDRTYPE_AUTH self._local_sock)
if self._ota_enable and not self._ota_enable_session: # spec https://shadowsocks.org/en/spec/one-time-auth.html
logging.warn('client one time auth is required') # ATYP & 0x10 == 1, then OTA is enabled.
return if self._ota_enable_session:
if self._ota_enable_session: data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:]
if len(data) < header_length + ONETIMEAUTH_BYTES: key = self._encryptor.cipher_iv + self._encryptor.key
logging.warn('one time auth header is too short') data += onetimeauth_gen(data, key)
return None data_to_send = self._encryptor.encrypt(data)
offset = header_length + ONETIMEAUTH_BYTES self._data_to_write_to_remote.append(data_to_send)
_hash = data[header_length: offset] # notice here may go into _handle_dns_resolved directly
_data = data[:header_length] self._dns_resolver.resolve(self._chosen_server[0],
key = self._encryptor.decipher_iv + self._encryptor.key self._handle_dns_resolved)
if onetimeauth_verify(_hash, _data, key) is False: else:
logging.warn('one time auth fail') if self._ota_enable_session:
self.destroy() data = data[header_length:]
return self._ota_chunk_data(data,
header_length += ONETIMEAUTH_BYTES self._data_to_write_to_remote.append)
self._remote_address = (common.to_str(remote_addr), remote_port) elif len(data) > header_length:
# pause reading self._data_to_write_to_remote.append(data[header_length:])
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) # notice here may go into _handle_dns_resolved directly
self._stage = STAGE_DNS self._dns_resolver.resolve(remote_addr,
if self._is_local: self._handle_dns_resolved)
# forward address to remote
self._write_to_sock((b'\x05\x00\x00\x01'
b'\x00\x00\x00\x00\x10\x10'),
self._local_sock)
# spec https://shadowsocks.org/en/spec/one-time-auth.html
# ATYP & 0x10 == 1, then OTA is enabled.
if self._ota_enable_session:
data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:]
key = self._encryptor.cipher_iv + self._encryptor.key
data += onetimeauth_gen(data, key)
data_to_send = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._chosen_server[0],
self._handle_dns_resolved)
else:
if self._ota_enable_session:
data = data[header_length:]
self._ota_chunk_data(data,
self._data_to_write_to_remote.append)
elif len(data) > header_length:
self._data_to_write_to_remote.append(data[header_length:])
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(remote_addr,
self._handle_dns_resolved)
except Exception as e:
self._log_error(e)
if self._config['verbose']:
traceback.print_exc()
self.destroy()
def _create_remote_socket(self, ip, port): def _create_remote_socket(self, ip, port):
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
@ -398,51 +396,50 @@ class TCPRelayHandler(object):
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
return remote_sock return remote_sock
@shell.exception_handle(self_=True)
def _handle_dns_resolved(self, result, error): def _handle_dns_resolved(self, result, error):
if error: if error:
self._log_error(error) addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(error, addr, port))
self.destroy()
return
if not (result and result[1]):
self.destroy() self.destroy()
return return
if result and result[1]:
ip = result[1]
try:
self._stage = STAGE_CONNECTING
remote_addr = ip
if self._is_local:
remote_port = self._chosen_server[1]
else:
remote_port = self._remote_address[1]
if self._is_local and self._config['fast_open']: ip = result[1]
# for fastopen: self._stage = STAGE_CONNECTING
# wait for more data arrive and send them in one SYN remote_addr = ip
self._stage = STAGE_CONNECTING if self._is_local:
# we don't have to wait for remote since it's not remote_port = self._chosen_server[1]
# created else:
self._update_stream(STREAM_UP, WAIT_STATUS_READING) remote_port = self._remote_address[1]
# TODO when there is already data in this packet
else: if self._is_local and self._config['fast_open']:
# else do connect # for fastopen:
remote_sock = self._create_remote_socket(remote_addr, # wait for more data arrive and send them in one SYN
remote_port) self._stage = STAGE_CONNECTING
try: # we don't have to wait for remote since it's not
remote_sock.connect((remote_addr, remote_port)) # created
except (OSError, IOError) as e: self._update_stream(STREAM_UP, WAIT_STATUS_READING)
if eventloop.errno_from_exception(e) == \ # TODO when there is already data in this packet
errno.EINPROGRESS: else:
pass # else do connect
self._loop.add(remote_sock, remote_sock = self._create_remote_socket(remote_addr,
eventloop.POLL_ERR | eventloop.POLL_OUT, remote_port)
self._server) try:
self._stage = STAGE_CONNECTING remote_sock.connect((remote_addr, remote_port))
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) except (OSError, IOError) as e:
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) if eventloop.errno_from_exception(e) == \
return errno.EINPROGRESS:
except Exception as e: pass
shell.print_exception(e) self._loop.add(remote_sock,
if self._config['verbose']: eventloop.POLL_ERR | eventloop.POLL_OUT,
traceback.print_exc() self._server)
self.destroy() self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
def _write_to_sock_remote(self, data): def _write_to_sock_remote(self, data):
self._write_to_sock(data, self._remote_sock) self._write_to_sock(data, self._remote_sock)
@ -661,10 +658,6 @@ class TCPRelayHandler(object):
else: else:
logging.warn('unknown socket') logging.warn('unknown socket')
def _log_error(self, e):
logging.error('%s when handling connection from %s:%d' %
(e, self._client_address[0], self._client_address[1]))
def destroy(self): def destroy(self):
# destroy the handler and release any resources # destroy the handler and release any resources
# promises: # promises: