* make tcprelay.py less nested

* import traceback module at the top

* make loops in DNSResolver less nested

* make manager.py less nested

* introduce exception_handle decorator

make try/except block more clean

* apply exception_handle decorator to tcprelay

* quote condition judgement

* pep8 fix
This commit is contained in:
ahxxm 2016-10-07 12:30:17 +08:00 committed by mengskysama
parent 5c11527e1b
commit 5cd9f04948
6 changed files with 248 additions and 203 deletions

View File

@ -276,15 +276,18 @@ class DNSResolver(object):
content = f.readlines()
for line in content:
line = line.strip()
if line:
if line.startswith(b'nameserver'):
parts = line.split()
if len(parts) >= 2:
server = parts[1]
if common.is_ip(server) == socket.AF_INET:
if type(server) != str:
server = server.decode('utf8')
self._servers.append(server)
if not (line and line.startswith(b'nameserver')):
continue
parts = line.split()
if len(parts) < 2:
continue
server = parts[1]
if common.is_ip(server) == socket.AF_INET:
if type(server) != str:
server = server.decode('utf8')
self._servers.append(server)
except IOError:
pass
if not self._servers:
@ -299,13 +302,17 @@ class DNSResolver(object):
for line in f.readlines():
line = line.strip()
parts = line.split()
if len(parts) >= 2:
ip = parts[0]
if common.is_ip(ip):
for i in range(1, len(parts)):
hostname = parts[i]
if hostname:
self._hosts[hostname] = ip
if len(parts) < 2:
continue
ip = parts[0]
if not common.is_ip(ip):
continue
for i in range(1, len(parts)):
hostname = parts[i]
if hostname:
self._hosts[hostname] = ip
except IOError:
self._hosts['localhost'] = '127.0.0.1'

View File

@ -25,6 +25,7 @@ import os
import time
import socket
import select
import traceback
import errno
import logging
from collections import defaultdict
@ -204,7 +205,6 @@ class EventLoop(object):
logging.debug('poll:%s', e)
else:
logging.error('poll:%s', e)
import traceback
traceback.print_exc()
continue

View File

@ -27,6 +27,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns
@shell.exception_handle(self_=False, exit_code=1)
def main():
shell.check_python()
@ -37,36 +38,31 @@ def main():
os.chdir(p)
config = shell.get_config(True)
daemon.daemon_exec(config)
try:
logging.info("starting local at %s:%d" %
(config['local_address'], config['local_port']))
logging.info("starting local at %s:%d" %
(config['local_address'], config['local_port']))
dns_resolver = asyncdns.DNSResolver()
tcp_server = tcprelay.TCPRelay(config, dns_resolver, True)
udp_server = udprelay.UDPRelay(config, dns_resolver, True)
loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop)
tcp_server.add_to_loop(loop)
udp_server.add_to_loop(loop)
dns_resolver = asyncdns.DNSResolver()
tcp_server = tcprelay.TCPRelay(config, dns_resolver, True)
udp_server = udprelay.UDPRelay(config, dns_resolver, True)
loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop)
tcp_server.add_to_loop(loop)
udp_server.add_to_loop(loop)
def handler(signum, _):
logging.warn('received SIGQUIT, doing graceful shutting down..')
tcp_server.close(next_tick=True)
udp_server.close(next_tick=True)
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler)
def handler(signum, _):
logging.warn('received SIGQUIT, doing graceful shutting down..')
tcp_server.close(next_tick=True)
udp_server.close(next_tick=True)
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler)
def int_handler(signum, _):
sys.exit(1)
signal.signal(signal.SIGINT, int_handler)
daemon.set_user(config.get('user', None))
loop.run()
except Exception as e:
shell.print_exception(e)
def int_handler(signum, _):
sys.exit(1)
signal.signal(signal.SIGINT, int_handler)
daemon.set_user(config.get('user', None))
loop.run()
if __name__ == '__main__':
main()

View File

@ -173,18 +173,20 @@ class Manager(object):
self._statistics.clear()
def _send_control_data(self, data):
if self._control_client_addr:
try:
self._control_socket.sendto(data, self._control_client_addr)
except (socket.error, OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
return
else:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
if not self._control_client_addr:
return
try:
self._control_socket.sendto(data, self._control_client_addr)
except (socket.error, OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS,
errno.EWOULDBLOCK):
return
else:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
def run(self):
self._loop.run()

View File

@ -23,6 +23,10 @@ import json
import sys
import getopt
import logging
import traceback
from functools import wraps
from shadowsocks.common import to_bytes, to_str, IPNetwork
from shadowsocks import encrypt
@ -53,6 +57,49 @@ def print_exception(e):
traceback.print_exc()
def exception_handle(self_, err_msg=None, exit_code=None,
destroy=False, conn_err=False):
# self_: if function passes self as first arg
def process_exception(e, self=None):
print_exception(e)
if err_msg:
logging.error(err_msg)
if exit_code:
sys.exit(1)
if not self_:
return
if conn_err:
addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(e, addr, port))
if self._config['verbose']:
traceback.print_exc()
if destroy:
self.destroy()
def decorator(func):
if self_:
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
func(self, *args, **kwargs)
except Exception as e:
process_exception(e, self)
else:
@wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
process_exception(e)
return wrapper
return decorator
def print_shadowsocks():
version = ''
try:

View File

@ -190,21 +190,23 @@ class TCPRelayHandler(object):
if self._upstream_status != status:
self._upstream_status = status
dirty = True
if dirty:
if self._local_sock:
event = eventloop.POLL_ERR
if self._downstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_OUT
if self._upstream_status & WAIT_STATUS_READING:
event |= eventloop.POLL_IN
self._loop.modify(self._local_sock, event)
if self._remote_sock:
event = eventloop.POLL_ERR
if self._downstream_status & WAIT_STATUS_READING:
event |= eventloop.POLL_IN
if self._upstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_OUT
self._loop.modify(self._remote_sock, event)
if not dirty:
return
if self._local_sock:
event = eventloop.POLL_ERR
if self._downstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_OUT
if self._upstream_status & WAIT_STATUS_READING:
event |= eventloop.POLL_IN
self._loop.modify(self._local_sock, event)
if self._remote_sock:
event = eventloop.POLL_ERR
if self._downstream_status & WAIT_STATUS_READING:
event |= eventloop.POLL_IN
if self._upstream_status & WAIT_STATUS_WRITING:
event |= eventloop.POLL_OUT
self._loop.modify(self._remote_sock, event)
def _write_to_sock(self, data, sock):
# write data to sock
@ -247,19 +249,20 @@ class TCPRelayHandler(object):
return True
def _handle_stage_connecting(self, data):
if self._is_local:
if self._ota_enable_session:
data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data)
else:
if not self._is_local:
if self._ota_enable_session:
self._ota_chunk_data(data,
self._data_to_write_to_remote.append)
else:
self._data_to_write_to_remote.append(data)
if self._is_local and not self._fastopen_connected and \
self._config['fast_open']:
return
if self._ota_enable_session:
data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data)
if self._config['fast_open'] and not self._fastopen_connected:
# for sslocal and fastopen, we basically wait for data and use
# sendto to connect
try:
@ -293,93 +296,88 @@ class TCPRelayHandler(object):
traceback.print_exc()
self.destroy()
@shell.exception_handle(self_=True, destroy=True, conn_err=True)
def _handle_stage_addr(self, data):
try:
if self._is_local:
cmd = common.ord(data[1])
if cmd == CMD_UDP_ASSOCIATE:
logging.debug('UDP associate')
if self._local_sock.family == socket.AF_INET6:
header = b'\x05\x00\x00\x04'
else:
header = b'\x05\x00\x00\x01'
addr, port = self._local_sock.getsockname()[:2]
addr_to_send = socket.inet_pton(self._local_sock.family,
addr)
port_to_send = struct.pack('>H', port)
self._write_to_sock(header + addr_to_send + port_to_send,
self._local_sock)
self._stage = STAGE_UDP_ASSOC
# just wait for the client to disconnect
return
elif cmd == CMD_CONNECT:
# just trim VER CMD RSV
data = data[3:]
if self._is_local:
cmd = common.ord(data[1])
if cmd == CMD_UDP_ASSOCIATE:
logging.debug('UDP associate')
if self._local_sock.family == socket.AF_INET6:
header = b'\x05\x00\x00\x04'
else:
logging.error('unknown command %d', cmd)
header = b'\x05\x00\x00\x01'
addr, port = self._local_sock.getsockname()[:2]
addr_to_send = socket.inet_pton(self._local_sock.family,
addr)
port_to_send = struct.pack('>H', port)
self._write_to_sock(header + addr_to_send + port_to_send,
self._local_sock)
self._stage = STAGE_UDP_ASSOC
# just wait for the client to disconnect
return
elif cmd == CMD_CONNECT:
# just trim VER CMD RSV
data = data[3:]
else:
logging.error('unknown command %d', cmd)
self.destroy()
return
header_result = parse_header(data)
if header_result is None:
raise Exception('can not parse header')
addrtype, remote_addr, remote_port, header_length = header_result
logging.info('connecting %s:%d from %s:%d' %
(common.to_str(remote_addr), remote_port,
self._client_address[0], self._client_address[1]))
if self._is_local is False:
# spec https://shadowsocks.org/en/spec/one-time-auth.html
self._ota_enable_session = addrtype & ADDRTYPE_AUTH
if self._ota_enable and not self._ota_enable_session:
logging.warn('client one time auth is required')
return
if self._ota_enable_session:
if len(data) < header_length + ONETIMEAUTH_BYTES:
logging.warn('one time auth header is too short')
return None
offset = header_length + ONETIMEAUTH_BYTES
_hash = data[header_length: offset]
_data = data[:header_length]
key = self._encryptor.decipher_iv + self._encryptor.key
if onetimeauth_verify(_hash, _data, key) is False:
logging.warn('one time auth fail')
self.destroy()
return
header_result = parse_header(data)
if header_result is None:
raise Exception('can not parse header')
addrtype, remote_addr, remote_port, header_length = header_result
logging.info('connecting %s:%d from %s:%d' %
(common.to_str(remote_addr), remote_port,
self._client_address[0], self._client_address[1]))
if self._is_local is False:
# spec https://shadowsocks.org/en/spec/one-time-auth.html
self._ota_enable_session = addrtype & ADDRTYPE_AUTH
if self._ota_enable and not self._ota_enable_session:
logging.warn('client one time auth is required')
return
if self._ota_enable_session:
if len(data) < header_length + ONETIMEAUTH_BYTES:
logging.warn('one time auth header is too short')
return None
offset = header_length + ONETIMEAUTH_BYTES
_hash = data[header_length: offset]
_data = data[:header_length]
key = self._encryptor.decipher_iv + self._encryptor.key
if onetimeauth_verify(_hash, _data, key) is False:
logging.warn('one time auth fail')
self.destroy()
return
header_length += ONETIMEAUTH_BYTES
self._remote_address = (common.to_str(remote_addr), remote_port)
# pause reading
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
self._stage = STAGE_DNS
if self._is_local:
# forward address to remote
self._write_to_sock((b'\x05\x00\x00\x01'
b'\x00\x00\x00\x00\x10\x10'),
self._local_sock)
# spec https://shadowsocks.org/en/spec/one-time-auth.html
# ATYP & 0x10 == 1, then OTA is enabled.
if self._ota_enable_session:
data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:]
key = self._encryptor.cipher_iv + self._encryptor.key
data += onetimeauth_gen(data, key)
data_to_send = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._chosen_server[0],
self._handle_dns_resolved)
else:
if self._ota_enable_session:
data = data[header_length:]
self._ota_chunk_data(data,
self._data_to_write_to_remote.append)
elif len(data) > header_length:
self._data_to_write_to_remote.append(data[header_length:])
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(remote_addr,
self._handle_dns_resolved)
except Exception as e:
self._log_error(e)
if self._config['verbose']:
traceback.print_exc()
self.destroy()
header_length += ONETIMEAUTH_BYTES
self._remote_address = (common.to_str(remote_addr), remote_port)
# pause reading
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
self._stage = STAGE_DNS
if self._is_local:
# forward address to remote
self._write_to_sock((b'\x05\x00\x00\x01'
b'\x00\x00\x00\x00\x10\x10'),
self._local_sock)
# spec https://shadowsocks.org/en/spec/one-time-auth.html
# ATYP & 0x10 == 1, then OTA is enabled.
if self._ota_enable_session:
data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:]
key = self._encryptor.cipher_iv + self._encryptor.key
data += onetimeauth_gen(data, key)
data_to_send = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._chosen_server[0],
self._handle_dns_resolved)
else:
if self._ota_enable_session:
data = data[header_length:]
self._ota_chunk_data(data,
self._data_to_write_to_remote.append)
elif len(data) > header_length:
self._data_to_write_to_remote.append(data[header_length:])
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(remote_addr,
self._handle_dns_resolved)
def _create_remote_socket(self, ip, port):
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
@ -398,51 +396,50 @@ class TCPRelayHandler(object):
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
return remote_sock
@shell.exception_handle(self_=True)
def _handle_dns_resolved(self, result, error):
if error:
self._log_error(error)
addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(error, addr, port))
self.destroy()
return
if not (result and result[1]):
self.destroy()
return
if result and result[1]:
ip = result[1]
try:
self._stage = STAGE_CONNECTING
remote_addr = ip
if self._is_local:
remote_port = self._chosen_server[1]
else:
remote_port = self._remote_address[1]
if self._is_local and self._config['fast_open']:
# for fastopen:
# wait for more data arrive and send them in one SYN
self._stage = STAGE_CONNECTING
# we don't have to wait for remote since it's not
# created
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
# TODO when there is already data in this packet
else:
# else do connect
remote_sock = self._create_remote_socket(remote_addr,
remote_port)
try:
remote_sock.connect((remote_addr, remote_port))
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == \
errno.EINPROGRESS:
pass
self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT,
self._server)
self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
return
except Exception as e:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
self.destroy()
ip = result[1]
self._stage = STAGE_CONNECTING
remote_addr = ip
if self._is_local:
remote_port = self._chosen_server[1]
else:
remote_port = self._remote_address[1]
if self._is_local and self._config['fast_open']:
# for fastopen:
# wait for more data arrive and send them in one SYN
self._stage = STAGE_CONNECTING
# we don't have to wait for remote since it's not
# created
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
# TODO when there is already data in this packet
else:
# else do connect
remote_sock = self._create_remote_socket(remote_addr,
remote_port)
try:
remote_sock.connect((remote_addr, remote_port))
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == \
errno.EINPROGRESS:
pass
self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT,
self._server)
self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
def _write_to_sock_remote(self, data):
self._write_to_sock(data, self._remote_sock)
@ -661,10 +658,6 @@ class TCPRelayHandler(object):
else:
logging.warn('unknown socket')
def _log_error(self, e):
logging.error('%s when handling connection from %s:%d' %
(e, self._client_address[0], self._client_address[1]))
def destroy(self):
# destroy the handler and release any resources
# promises: