* make tcprelay.py less nested

* import traceback module at the top

* make loops in DNSResolver less nested

* make manager.py less nested

* introduce exception_handle decorator

make try/except block more clean

* apply exception_handle decorator to tcprelay

* quote condition judgement

* pep8 fix
This commit is contained in:
ahxxm 2016-10-07 12:30:17 +08:00 committed by mengskysama
parent 5c11527e1b
commit 5cd9f04948
6 changed files with 248 additions and 203 deletions

View file

@ -276,10 +276,13 @@ class DNSResolver(object):
content = f.readlines() content = f.readlines()
for line in content: for line in content:
line = line.strip() line = line.strip()
if line: if not (line and line.startswith(b'nameserver')):
if line.startswith(b'nameserver'): continue
parts = line.split() parts = line.split()
if len(parts) >= 2: if len(parts) < 2:
continue
server = parts[1] server = parts[1]
if common.is_ip(server) == socket.AF_INET: if common.is_ip(server) == socket.AF_INET:
if type(server) != str: if type(server) != str:
@ -299,9 +302,13 @@ class DNSResolver(object):
for line in f.readlines(): for line in f.readlines():
line = line.strip() line = line.strip()
parts = line.split() parts = line.split()
if len(parts) >= 2: if len(parts) < 2:
continue
ip = parts[0] ip = parts[0]
if common.is_ip(ip): if not common.is_ip(ip):
continue
for i in range(1, len(parts)): for i in range(1, len(parts)):
hostname = parts[i] hostname = parts[i]
if hostname: if hostname:

View file

@ -25,6 +25,7 @@ import os
import time import time
import socket import socket
import select import select
import traceback
import errno import errno
import logging import logging
from collections import defaultdict from collections import defaultdict
@ -204,7 +205,6 @@ class EventLoop(object):
logging.debug('poll:%s', e) logging.debug('poll:%s', e)
else: else:
logging.error('poll:%s', e) logging.error('poll:%s', e)
import traceback
traceback.print_exc() traceback.print_exc()
continue continue

View file

@ -27,6 +27,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns
@shell.exception_handle(self_=False, exit_code=1)
def main(): def main():
shell.check_python() shell.check_python()
@ -37,10 +38,8 @@ def main():
os.chdir(p) os.chdir(p)
config = shell.get_config(True) config = shell.get_config(True)
daemon.daemon_exec(config) daemon.daemon_exec(config)
try:
logging.info("starting local at %s:%d" % logging.info("starting local at %s:%d" %
(config['local_address'], config['local_port'])) (config['local_address'], config['local_port']))
@ -64,9 +63,6 @@ def main():
daemon.set_user(config.get('user', None)) daemon.set_user(config.get('user', None))
loop.run() loop.run()
except Exception as e:
shell.print_exception(e)
sys.exit(1)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View file

@ -173,7 +173,9 @@ class Manager(object):
self._statistics.clear() self._statistics.clear()
def _send_control_data(self, data): def _send_control_data(self, data):
if self._control_client_addr: if not self._control_client_addr:
return
try: try:
self._control_socket.sendto(data, self._control_client_addr) self._control_socket.sendto(data, self._control_client_addr)
except (socket.error, OSError, IOError) as e: except (socket.error, OSError, IOError) as e:

View file

@ -23,6 +23,10 @@ import json
import sys import sys
import getopt import getopt
import logging import logging
import traceback
from functools import wraps
from shadowsocks.common import to_bytes, to_str, IPNetwork from shadowsocks.common import to_bytes, to_str, IPNetwork
from shadowsocks import encrypt from shadowsocks import encrypt
@ -53,6 +57,49 @@ def print_exception(e):
traceback.print_exc() traceback.print_exc()
def exception_handle(self_, err_msg=None, exit_code=None,
destroy=False, conn_err=False):
# self_: if function passes self as first arg
def process_exception(e, self=None):
print_exception(e)
if err_msg:
logging.error(err_msg)
if exit_code:
sys.exit(1)
if not self_:
return
if conn_err:
addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(e, addr, port))
if self._config['verbose']:
traceback.print_exc()
if destroy:
self.destroy()
def decorator(func):
if self_:
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
func(self, *args, **kwargs)
except Exception as e:
process_exception(e, self)
else:
@wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
process_exception(e)
return wrapper
return decorator
def print_shadowsocks(): def print_shadowsocks():
version = '' version = ''
try: try:

View file

@ -190,7 +190,9 @@ class TCPRelayHandler(object):
if self._upstream_status != status: if self._upstream_status != status:
self._upstream_status = status self._upstream_status = status
dirty = True dirty = True
if dirty: if not dirty:
return
if self._local_sock: if self._local_sock:
event = eventloop.POLL_ERR event = eventloop.POLL_ERR
if self._downstream_status & WAIT_STATUS_WRITING: if self._downstream_status & WAIT_STATUS_WRITING:
@ -247,19 +249,20 @@ class TCPRelayHandler(object):
return True return True
def _handle_stage_connecting(self, data): def _handle_stage_connecting(self, data):
if self._is_local: if not self._is_local:
if self._ota_enable_session:
data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data)
else:
if self._ota_enable_session: if self._ota_enable_session:
self._ota_chunk_data(data, self._ota_chunk_data(data,
self._data_to_write_to_remote.append) self._data_to_write_to_remote.append)
else: else:
self._data_to_write_to_remote.append(data) self._data_to_write_to_remote.append(data)
if self._is_local and not self._fastopen_connected and \ return
self._config['fast_open']:
if self._ota_enable_session:
data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data)
if self._config['fast_open'] and not self._fastopen_connected:
# for sslocal and fastopen, we basically wait for data and use # for sslocal and fastopen, we basically wait for data and use
# sendto to connect # sendto to connect
try: try:
@ -293,8 +296,8 @@ class TCPRelayHandler(object):
traceback.print_exc() traceback.print_exc()
self.destroy() self.destroy()
@shell.exception_handle(self_=True, destroy=True, conn_err=True)
def _handle_stage_addr(self, data): def _handle_stage_addr(self, data):
try:
if self._is_local: if self._is_local:
cmd = common.ord(data[1]) cmd = common.ord(data[1])
if cmd == CMD_UDP_ASSOCIATE: if cmd == CMD_UDP_ASSOCIATE:
@ -375,11 +378,6 @@ class TCPRelayHandler(object):
# notice here may go into _handle_dns_resolved directly # notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(remote_addr, self._dns_resolver.resolve(remote_addr,
self._handle_dns_resolved) self._handle_dns_resolved)
except Exception as e:
self._log_error(e)
if self._config['verbose']:
traceback.print_exc()
self.destroy()
def _create_remote_socket(self, ip, port): def _create_remote_socket(self, ip, port):
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
@ -398,14 +396,19 @@ class TCPRelayHandler(object):
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
return remote_sock return remote_sock
@shell.exception_handle(self_=True)
def _handle_dns_resolved(self, result, error): def _handle_dns_resolved(self, result, error):
if error: if error:
self._log_error(error) addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(error, addr, port))
self.destroy() self.destroy()
return return
if result and result[1]: if not (result and result[1]):
self.destroy()
return
ip = result[1] ip = result[1]
try:
self._stage = STAGE_CONNECTING self._stage = STAGE_CONNECTING
remote_addr = ip remote_addr = ip
if self._is_local: if self._is_local:
@ -437,12 +440,6 @@ class TCPRelayHandler(object):
self._stage = STAGE_CONNECTING self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
return
except Exception as e:
shell.print_exception(e)
if self._config['verbose']:
traceback.print_exc()
self.destroy()
def _write_to_sock_remote(self, data): def _write_to_sock_remote(self, data):
self._write_to_sock(data, self._remote_sock) self._write_to_sock(data, self._remote_sock)
@ -661,10 +658,6 @@ class TCPRelayHandler(object):
else: else:
logging.warn('unknown socket') logging.warn('unknown socket')
def _log_error(self, e):
logging.error('%s when handling connection from %s:%d' %
(e, self._client_address[0], self._client_address[1]))
def destroy(self): def destroy(self):
# destroy the handler and release any resources # destroy the handler and release any resources
# promises: # promises: