apply exception_handle decorator to tcprelay
This commit is contained in:
parent
52556d5e48
commit
bcb528901d
2 changed files with 130 additions and 134 deletions
|
@ -23,6 +23,7 @@ import json
|
|||
import sys
|
||||
import getopt
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from functools import wraps
|
||||
|
||||
|
@ -56,20 +57,28 @@ def print_exception(e):
|
|||
traceback.print_exc()
|
||||
|
||||
|
||||
def exception_handle(self_, err_msg=None, exit_code=None):
|
||||
"""
|
||||
:param self_: if function passes self as first arg
|
||||
:param err_msg:
|
||||
:param exit_code:
|
||||
:return:
|
||||
"""
|
||||
def process_exception(e):
|
||||
def exception_handle(self_, err_msg=None, exit_code=None,
|
||||
destroy=False, conn_err=False):
|
||||
# self_: if function passes self as first arg
|
||||
|
||||
def process_exception(self, e):
|
||||
print_exception(e)
|
||||
if err_msg:
|
||||
logging.error(err_msg)
|
||||
if exit_code:
|
||||
sys.exit(1)
|
||||
|
||||
if not self_:
|
||||
return
|
||||
|
||||
if conn_err:
|
||||
logging.error('%s when handling connection from %s:%d' %
|
||||
(e, self._client_address[0], self._client_address[1]))
|
||||
if self._config['verbose']:
|
||||
traceback.print_exc()
|
||||
if destroy:
|
||||
self.destroy()
|
||||
|
||||
def decorator(func):
|
||||
if self_:
|
||||
@wraps(func)
|
||||
|
@ -77,14 +86,14 @@ def exception_handle(self_, err_msg=None, exit_code=None):
|
|||
try:
|
||||
func(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
process_exception(e)
|
||||
process_exception(self, e)
|
||||
else:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
process_exception(e)
|
||||
process_exception(self, e)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
|
|
@ -296,93 +296,88 @@ class TCPRelayHandler(object):
|
|||
traceback.print_exc()
|
||||
self.destroy()
|
||||
|
||||
@shell.exception_handle(self_=True, destroy=True, conn_err=True)
|
||||
def _handle_stage_addr(self, data):
|
||||
try:
|
||||
if self._is_local:
|
||||
cmd = common.ord(data[1])
|
||||
if cmd == CMD_UDP_ASSOCIATE:
|
||||
logging.debug('UDP associate')
|
||||
if self._local_sock.family == socket.AF_INET6:
|
||||
header = b'\x05\x00\x00\x04'
|
||||
else:
|
||||
header = b'\x05\x00\x00\x01'
|
||||
addr, port = self._local_sock.getsockname()[:2]
|
||||
addr_to_send = socket.inet_pton(self._local_sock.family,
|
||||
addr)
|
||||
port_to_send = struct.pack('>H', port)
|
||||
self._write_to_sock(header + addr_to_send + port_to_send,
|
||||
self._local_sock)
|
||||
self._stage = STAGE_UDP_ASSOC
|
||||
# just wait for the client to disconnect
|
||||
return
|
||||
elif cmd == CMD_CONNECT:
|
||||
# just trim VER CMD RSV
|
||||
data = data[3:]
|
||||
if self._is_local:
|
||||
cmd = common.ord(data[1])
|
||||
if cmd == CMD_UDP_ASSOCIATE:
|
||||
logging.debug('UDP associate')
|
||||
if self._local_sock.family == socket.AF_INET6:
|
||||
header = b'\x05\x00\x00\x04'
|
||||
else:
|
||||
logging.error('unknown command %d', cmd)
|
||||
header = b'\x05\x00\x00\x01'
|
||||
addr, port = self._local_sock.getsockname()[:2]
|
||||
addr_to_send = socket.inet_pton(self._local_sock.family,
|
||||
addr)
|
||||
port_to_send = struct.pack('>H', port)
|
||||
self._write_to_sock(header + addr_to_send + port_to_send,
|
||||
self._local_sock)
|
||||
self._stage = STAGE_UDP_ASSOC
|
||||
# just wait for the client to disconnect
|
||||
return
|
||||
elif cmd == CMD_CONNECT:
|
||||
# just trim VER CMD RSV
|
||||
data = data[3:]
|
||||
else:
|
||||
logging.error('unknown command %d', cmd)
|
||||
self.destroy()
|
||||
return
|
||||
header_result = parse_header(data)
|
||||
if header_result is None:
|
||||
raise Exception('can not parse header')
|
||||
addrtype, remote_addr, remote_port, header_length = header_result
|
||||
logging.info('connecting %s:%d from %s:%d' %
|
||||
(common.to_str(remote_addr), remote_port,
|
||||
self._client_address[0], self._client_address[1]))
|
||||
if self._is_local is False:
|
||||
# spec https://shadowsocks.org/en/spec/one-time-auth.html
|
||||
self._ota_enable_session = addrtype & ADDRTYPE_AUTH
|
||||
if self._ota_enable and not self._ota_enable_session:
|
||||
logging.warn('client one time auth is required')
|
||||
return
|
||||
if self._ota_enable_session:
|
||||
if len(data) < header_length + ONETIMEAUTH_BYTES:
|
||||
logging.warn('one time auth header is too short')
|
||||
return None
|
||||
offset = header_length + ONETIMEAUTH_BYTES
|
||||
_hash = data[header_length: offset]
|
||||
_data = data[:header_length]
|
||||
key = self._encryptor.decipher_iv + self._encryptor.key
|
||||
if onetimeauth_verify(_hash, _data, key) is False:
|
||||
logging.warn('one time auth fail')
|
||||
self.destroy()
|
||||
return
|
||||
header_result = parse_header(data)
|
||||
if header_result is None:
|
||||
raise Exception('can not parse header')
|
||||
addrtype, remote_addr, remote_port, header_length = header_result
|
||||
logging.info('connecting %s:%d from %s:%d' %
|
||||
(common.to_str(remote_addr), remote_port,
|
||||
self._client_address[0], self._client_address[1]))
|
||||
if self._is_local is False:
|
||||
# spec https://shadowsocks.org/en/spec/one-time-auth.html
|
||||
self._ota_enable_session = addrtype & ADDRTYPE_AUTH
|
||||
if self._ota_enable and not self._ota_enable_session:
|
||||
logging.warn('client one time auth is required')
|
||||
return
|
||||
if self._ota_enable_session:
|
||||
if len(data) < header_length + ONETIMEAUTH_BYTES:
|
||||
logging.warn('one time auth header is too short')
|
||||
return None
|
||||
offset = header_length + ONETIMEAUTH_BYTES
|
||||
_hash = data[header_length: offset]
|
||||
_data = data[:header_length]
|
||||
key = self._encryptor.decipher_iv + self._encryptor.key
|
||||
if onetimeauth_verify(_hash, _data, key) is False:
|
||||
logging.warn('one time auth fail')
|
||||
self.destroy()
|
||||
return
|
||||
header_length += ONETIMEAUTH_BYTES
|
||||
self._remote_address = (common.to_str(remote_addr), remote_port)
|
||||
# pause reading
|
||||
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
|
||||
self._stage = STAGE_DNS
|
||||
if self._is_local:
|
||||
# forward address to remote
|
||||
self._write_to_sock((b'\x05\x00\x00\x01'
|
||||
b'\x00\x00\x00\x00\x10\x10'),
|
||||
self._local_sock)
|
||||
# spec https://shadowsocks.org/en/spec/one-time-auth.html
|
||||
# ATYP & 0x10 == 1, then OTA is enabled.
|
||||
if self._ota_enable_session:
|
||||
data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:]
|
||||
key = self._encryptor.cipher_iv + self._encryptor.key
|
||||
data += onetimeauth_gen(data, key)
|
||||
data_to_send = self._encryptor.encrypt(data)
|
||||
self._data_to_write_to_remote.append(data_to_send)
|
||||
# notice here may go into _handle_dns_resolved directly
|
||||
self._dns_resolver.resolve(self._chosen_server[0],
|
||||
self._handle_dns_resolved)
|
||||
else:
|
||||
if self._ota_enable_session:
|
||||
data = data[header_length:]
|
||||
self._ota_chunk_data(data,
|
||||
self._data_to_write_to_remote.append)
|
||||
elif len(data) > header_length:
|
||||
self._data_to_write_to_remote.append(data[header_length:])
|
||||
# notice here may go into _handle_dns_resolved directly
|
||||
self._dns_resolver.resolve(remote_addr,
|
||||
self._handle_dns_resolved)
|
||||
except Exception as e:
|
||||
self._log_error(e)
|
||||
if self._config['verbose']:
|
||||
traceback.print_exc()
|
||||
self.destroy()
|
||||
header_length += ONETIMEAUTH_BYTES
|
||||
self._remote_address = (common.to_str(remote_addr), remote_port)
|
||||
# pause reading
|
||||
self._update_stream(STREAM_UP, WAIT_STATUS_WRITING)
|
||||
self._stage = STAGE_DNS
|
||||
if self._is_local:
|
||||
# forward address to remote
|
||||
self._write_to_sock((b'\x05\x00\x00\x01'
|
||||
b'\x00\x00\x00\x00\x10\x10'),
|
||||
self._local_sock)
|
||||
# spec https://shadowsocks.org/en/spec/one-time-auth.html
|
||||
# ATYP & 0x10 == 1, then OTA is enabled.
|
||||
if self._ota_enable_session:
|
||||
data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:]
|
||||
key = self._encryptor.cipher_iv + self._encryptor.key
|
||||
data += onetimeauth_gen(data, key)
|
||||
data_to_send = self._encryptor.encrypt(data)
|
||||
self._data_to_write_to_remote.append(data_to_send)
|
||||
# notice here may go into _handle_dns_resolved directly
|
||||
self._dns_resolver.resolve(self._chosen_server[0],
|
||||
self._handle_dns_resolved)
|
||||
else:
|
||||
if self._ota_enable_session:
|
||||
data = data[header_length:]
|
||||
self._ota_chunk_data(data,
|
||||
self._data_to_write_to_remote.append)
|
||||
elif len(data) > header_length:
|
||||
self._data_to_write_to_remote.append(data[header_length:])
|
||||
# notice here may go into _handle_dns_resolved directly
|
||||
self._dns_resolver.resolve(remote_addr,
|
||||
self._handle_dns_resolved)
|
||||
|
||||
def _create_remote_socket(self, ip, port):
|
||||
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
|
||||
|
@ -401,9 +396,11 @@ class TCPRelayHandler(object):
|
|||
remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||
return remote_sock
|
||||
|
||||
@shell.exception_handle(self_=True)
|
||||
def _handle_dns_resolved(self, result, error):
|
||||
if error:
|
||||
self._log_error(error)
|
||||
logging.error('%s when handling connection from %s:%d' %
|
||||
(error, self._client_address[0], self._client_address[1]))
|
||||
self.destroy()
|
||||
return
|
||||
if not (result and result[1]):
|
||||
|
@ -411,43 +408,37 @@ class TCPRelayHandler(object):
|
|||
return
|
||||
|
||||
ip = result[1]
|
||||
try:
|
||||
self._stage = STAGE_CONNECTING
|
||||
remote_addr = ip
|
||||
if self._is_local:
|
||||
remote_port = self._chosen_server[1]
|
||||
else:
|
||||
remote_port = self._remote_address[1]
|
||||
self._stage = STAGE_CONNECTING
|
||||
remote_addr = ip
|
||||
if self._is_local:
|
||||
remote_port = self._chosen_server[1]
|
||||
else:
|
||||
remote_port = self._remote_address[1]
|
||||
|
||||
if self._is_local and self._config['fast_open']:
|
||||
# for fastopen:
|
||||
# wait for more data arrive and send them in one SYN
|
||||
self._stage = STAGE_CONNECTING
|
||||
# we don't have to wait for remote since it's not
|
||||
# created
|
||||
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
|
||||
# TODO when there is already data in this packet
|
||||
else:
|
||||
# else do connect
|
||||
remote_sock = self._create_remote_socket(remote_addr,
|
||||
remote_port)
|
||||
try:
|
||||
remote_sock.connect((remote_addr, remote_port))
|
||||
except (OSError, IOError) as e:
|
||||
if eventloop.errno_from_exception(e) == \
|
||||
errno.EINPROGRESS:
|
||||
pass
|
||||
self._loop.add(remote_sock,
|
||||
eventloop.POLL_ERR | eventloop.POLL_OUT,
|
||||
self._server)
|
||||
self._stage = STAGE_CONNECTING
|
||||
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
|
||||
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
|
||||
return
|
||||
except Exception as e:
|
||||
shell.print_exception(e)
|
||||
if self._config['verbose']:
|
||||
traceback.print_exc()
|
||||
if self._is_local and self._config['fast_open']:
|
||||
# for fastopen:
|
||||
# wait for more data arrive and send them in one SYN
|
||||
self._stage = STAGE_CONNECTING
|
||||
# we don't have to wait for remote since it's not
|
||||
# created
|
||||
self._update_stream(STREAM_UP, WAIT_STATUS_READING)
|
||||
# TODO when there is already data in this packet
|
||||
else:
|
||||
# else do connect
|
||||
remote_sock = self._create_remote_socket(remote_addr,
|
||||
remote_port)
|
||||
try:
|
||||
remote_sock.connect((remote_addr, remote_port))
|
||||
except (OSError, IOError) as e:
|
||||
if eventloop.errno_from_exception(e) == \
|
||||
errno.EINPROGRESS:
|
||||
pass
|
||||
self._loop.add(remote_sock,
|
||||
eventloop.POLL_ERR | eventloop.POLL_OUT,
|
||||
self._server)
|
||||
self._stage = STAGE_CONNECTING
|
||||
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
|
||||
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
|
||||
|
||||
def _write_to_sock_remote(self, data):
|
||||
self._write_to_sock(data, self._remote_sock)
|
||||
|
@ -666,10 +657,6 @@ class TCPRelayHandler(object):
|
|||
else:
|
||||
logging.warn('unknown socket')
|
||||
|
||||
def _log_error(self, e):
|
||||
logging.error('%s when handling connection from %s:%d' %
|
||||
(e, self._client_address[0], self._client_address[1]))
|
||||
|
||||
def destroy(self):
|
||||
# destroy the handler and release any resources
|
||||
# promises:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue