This commit is contained in:
jsy 2016-01-11 23:18:16 +08:00
parent ee391c7773
commit a4a87eb127
2 changed files with 100 additions and 90 deletions

View file

@ -109,13 +109,13 @@ class TCPRelayHandler(object):
self._encryptor = encrypt.Encryptor(config['password'], self._encryptor = encrypt.Encryptor(config['password'],
config['method']) config['method'])
if 'one_time_auth' in config and config['one_time_auth']: if 'one_time_auth' in config and config['one_time_auth']:
self._one_time_auth_enable = True self._ota_enable = True
else: else:
self._one_time_auth_enable = False self._ota_enable = False
self._one_time_auth_buff_head = '' self._ota_buff_head = ''
self._one_time_auth_buff_data = '' self._ota_buff_data = ''
self._one_time_auth_len = 0 self._ota_len = 0
self._one_time_auth_chunk_idx = 0 self._ota_chunk_idx = 0
self._fastopen_connected = False self._fastopen_connected = False
self._data_to_write_to_local = [] self._data_to_write_to_local = []
self._data_to_write_to_remote = [] self._data_to_write_to_remote = []
@ -233,14 +233,14 @@ class TCPRelayHandler(object):
def _handle_stage_connecting(self, data): def _handle_stage_connecting(self, data):
if self._is_local: if self._is_local:
if self._one_time_auth_enable: if self._ota_enable:
data = self._one_time_auth_chunk_data_gen(data) data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data) data = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data) self._data_to_write_to_remote.append(data)
else: else:
if self._one_time_auth_enable: if self._ota_enable:
self._one_time_auth_chunk_data(data, self._ota_chunk_data(data,
self._data_to_write_to_remote.append) self._data_to_write_to_remote.append)
if self._is_local and not self._fastopen_connected and \ if self._is_local and not self._fastopen_connected and \
self._config['fast_open']: self._config['fast_open']:
# for sslocal and fastopen, we basically wait for data and use # for sslocal and fastopen, we basically wait for data and use
@ -254,7 +254,8 @@ class TCPRelayHandler(object):
self._loop.add(remote_sock, eventloop.POLL_ERR, self._server) self._loop.add(remote_sock, eventloop.POLL_ERR, self._server)
data = b''.join(self._data_to_write_to_remote) data = b''.join(self._data_to_write_to_remote)
l = len(data) l = len(data)
s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) s = remote_sock.sendto(data, MSG_FASTOPEN,
self._chosen_server)
if s < l: if s < l:
data = data[s:] data = data[s:]
self._data_to_write_to_remote = [data] self._data_to_write_to_remote = [data]
@ -310,13 +311,15 @@ class TCPRelayHandler(object):
self._client_address[0], self._client_address[1])) self._client_address[0], self._client_address[1]))
if self._is_local is False: if self._is_local is False:
# spec https://shadowsocks.org/en/spec/one-time-auth.html # spec https://shadowsocks.org/en/spec/one-time-auth.html
if self._one_time_auth_enable or addrtype & ADDRTYPE_AUTH: if self._ota_enable or addrtype & ADDRTYPE_AUTH:
if len(data) < header_length + ONETIMEAUTH_BYTES: if len(data) < header_length + ONETIMEAUTH_BYTES:
logging.warn('one time auth header is too short') logging.warn('one time auth header is too short')
return None return None
if onetimeauth_verify(data[header_length: header_length+ONETIMEAUTH_BYTES], offset = header_length + ONETIMEAUTH_BYTES
data[:header_length], _hash = data[header_length: offset]
self._encryptor.decipher_iv + self._encryptor.key) is False: _data = data[:header_length]
key = self._encryptor.decipher_iv + self._encryptor.key
if onetimeauth_verify(_hash, _data, key) is False:
logging.warn('one time auth fail') logging.warn('one time auth fail')
self.destroy() self.destroy()
header_length += ONETIMEAUTH_BYTES header_length += ONETIMEAUTH_BYTES
@ -331,19 +334,20 @@ class TCPRelayHandler(object):
self._local_sock) self._local_sock)
# spec https://shadowsocks.org/en/spec/one-time-auth.html # spec https://shadowsocks.org/en/spec/one-time-auth.html
# ATYP & 0x10 == 1, then OTA is enabled. # ATYP & 0x10 == 1, then OTA is enabled.
if self._one_time_auth_enable: if self._ota_enable:
data = chr(ord(data[0]) | ADDRTYPE_AUTH) + data[1:] data = chr(ord(data[0]) | ADDRTYPE_AUTH) + data[1:]
data += onetimeauth_gen(data, self._encryptor.cipher_iv + self._encryptor.key) key = self._encryptor.cipher_iv + self._encryptor.key
data += onetimeauth_gen(data, key)
data_to_send = self._encryptor.encrypt(data) data_to_send = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data_to_send) self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly # notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._chosen_server[0], self._dns_resolver.resolve(self._chosen_server[0],
self._handle_dns_resolved) self._handle_dns_resolved)
else: else:
if self._one_time_auth_enable: if self._ota_enable:
data = data[header_length:] data = data[header_length:]
self._one_time_auth_chunk_data(data, self._ota_chunk_data(data,
self._data_to_write_to_remote.append) self._data_to_write_to_remote.append)
elif len(data) > header_length: elif len(data) > header_length:
self._data_to_write_to_remote.append(data[header_length:]) self._data_to_write_to_remote.append(data[header_length:])
# notice here may go into _handle_dns_resolved directly # notice here may go into _handle_dns_resolved directly
@ -377,97 +381,99 @@ class TCPRelayHandler(object):
self._log_error(error) self._log_error(error)
self.destroy() self.destroy()
return return
if result: if result and result[1]:
ip = result[1] ip = result[1]
if ip: try:
try: self._stage = STAGE_CONNECTING
self._stage = STAGE_CONNECTING remote_addr = ip
remote_addr = ip if self._is_local:
if self._is_local: remote_port = self._chosen_server[1]
remote_port = self._chosen_server[1] else:
else: remote_port = self._remote_address[1]
remote_port = self._remote_address[1]
if self._is_local and self._config['fast_open']: if self._is_local and self._config['fast_open']:
# for fastopen: # for fastopen:
# wait for more data to arrive and send them in one SYN # wait for more data arrive and send them in one SYN
self._stage = STAGE_CONNECTING self._stage = STAGE_CONNECTING
# we don't have to wait for remote since it's not # we don't have to wait for remote since it's not
# created # created
self._update_stream(STREAM_UP, WAIT_STATUS_READING) self._update_stream(STREAM_UP, WAIT_STATUS_READING)
# TODO when there is already data in this packet # TODO when there is already data in this packet
else: else:
# else do connect # else do connect
remote_sock = self._create_remote_socket(remote_addr, remote_sock = self._create_remote_socket(remote_addr,
remote_port) remote_port)
try: try:
remote_sock.connect((remote_addr, remote_port)) remote_sock.connect((remote_addr, remote_port))
except (OSError, IOError) as e: except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == \ if eventloop.errno_from_exception(e) == \
errno.EINPROGRESS: errno.EINPROGRESS:
pass pass
self._loop.add(remote_sock, self._loop.add(remote_sock,
eventloop.POLL_ERR | eventloop.POLL_OUT, eventloop.POLL_ERR | eventloop.POLL_OUT,
self._server) self._server)
self._stage = STAGE_CONNECTING self._stage = STAGE_CONNECTING
self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING)
self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) self._update_stream(STREAM_DOWN, WAIT_STATUS_READING)
return return
except Exception as e: except Exception as e:
shell.print_exception(e) shell.print_exception(e)
if self._config['verbose']: if self._config['verbose']:
traceback.print_exc() traceback.print_exc()
self.destroy() self.destroy()
def _write_to_sock_remote(self, data): def _write_to_sock_remote(self, data):
self._write_to_sock(data, self._remote_sock) self._write_to_sock(data, self._remote_sock)
def _one_time_auth_chunk_data(self, data, data_cb): def _ota_chunk_data(self, data, data_cb):
# spec https://shadowsocks.org/en/spec/one-time-auth.html # spec https://shadowsocks.org/en/spec/one-time-auth.html
while len(data) > 0: while len(data) > 0:
if self._one_time_auth_len == 0: if self._ota_len == 0:
# get DATA.LEN + HMAC-SHA1 # get DATA.LEN + HMAC-SHA1
length = ONETIMEAUTH_CHUNK_BYTES - len(self._one_time_auth_buff_head) length = ONETIMEAUTH_CHUNK_BYTES - len(self._ota_buff_head)
self._one_time_auth_buff_head += data[:length] self._ota_buff_head += data[:length]
data = data[length:] data = data[length:]
if len(self._one_time_auth_buff_head) < ONETIMEAUTH_CHUNK_BYTES: if len(self._ota_buff_head) < ONETIMEAUTH_CHUNK_BYTES:
# wait more data # wait more data
return return
self._one_time_auth_len = struct.unpack('>H', data_len = self._ota_buff_head[:ONETIMEAUTH_CHUNK_DATA_LEN]
self._one_time_auth_buff_head[:ONETIMEAUTH_CHUNK_DATA_LEN])[0] self._ota_len = struct.unpack('>H', data_len)[0]
length = min(self._one_time_auth_len, len(data)) length = min(self._ota_len, len(data))
self._one_time_auth_buff_data += data[:length] self._ota_buff_data += data[:length]
data = data[length:] data = data[length:]
if len(self._one_time_auth_buff_data) == self._one_time_auth_len: if len(self._ota_buff_data) == self._ota_len:
# get a chunk data # get a chunk data
if onetimeauth_verify(self._one_time_auth_buff_head[ONETIMEAUTH_CHUNK_DATA_LEN:], _hash = self._ota_buff_head[ONETIMEAUTH_CHUNK_DATA_LEN:]
self._one_time_auth_buff_data, _data = self._ota_buff_data
self._encryptor.decipher_iv + struct.pack('>I', self._one_time_auth_chunk_idx)) \ index = struct.pack('>I', self._ota_chunk_idx)
is False: key = self._encryptor.decipher_iv + index
if onetimeauth_verify(_hash, _data, key) is False:
logging.warn('one time auth fail, drop chunk !') logging.warn('one time auth fail, drop chunk !')
else: else:
data_cb(self._one_time_auth_buff_data) data_cb(self._ota_buff_data)
self._one_time_auth_chunk_idx += 1 self._ota_chunk_idx += 1
self._one_time_auth_buff_head = '' self._ota_buff_head = ''
self._one_time_auth_buff_data = '' self._ota_buff_data = ''
self._one_time_auth_len = 0 self._ota_len = 0
return return
def _one_time_auth_chunk_data_gen(self, data): def _ota_chunk_data_gen(self, data):
data_len = struct.pack(">H", len(data)) data_len = struct.pack(">H", len(data))
sha110 = onetimeauth_gen(data, self._encryptor.cipher_iv + struct.pack('>I', self._one_time_auth_chunk_idx)) index = struct.pack('>I', self._ota_chunk_idx)
self._one_time_auth_chunk_idx += 1 key = self._encryptor.cipher_iv + index
sha110 = onetimeauth_gen(data, key)
self._ota_chunk_idx += 1
return data_len + sha110 + data return data_len + sha110 + data
def _handle_stage_stream(self, data): def _handle_stage_stream(self, data):
if self._is_local: if self._is_local:
if self._one_time_auth_enable: if self._ota_enable:
data = self._one_time_auth_chunk_data_gen(data) data = self._ota_chunk_data_gen(data)
data = self._encryptor.encrypt(data) data = self._encryptor.encrypt(data)
self._write_to_sock(data, self._remote_sock) self._write_to_sock(data, self._remote_sock)
else: else:
if self._one_time_auth_enable: if self._ota_enable:
self._one_time_auth_chunk_data(data, self._write_to_sock_remote) self._ota_chunk_data(data, self._write_to_sock_remote)
return return
def _on_local_read(self): def _on_local_read(self):

View file

@ -165,7 +165,9 @@ class UDPRelay(object):
data = encrypt.encrypt_all(self._password, self._method, 0, data) data = encrypt.encrypt_all(self._password, self._method, 0, data)
# decrypt data # decrypt data
if not data: if not data:
logging.debug('UDP handle_server: data is empty after decrypt') logging.debug(
'UDP handle_server: data is empty after decrypt'
)
return return
header_result = parse_header(data) header_result = parse_header(data)
if header_result is None: if header_result is None:
@ -181,9 +183,10 @@ class UDPRelay(object):
if len(data) < header_length + ONETIMEAUTH_BYTES: if len(data) < header_length + ONETIMEAUTH_BYTES:
logging.warn('UDP one time auth header is too short') logging.warn('UDP one time auth header is too short')
return return
if onetimeauth_verify(data[-ONETIMEAUTH_BYTES:], _hash = data[-ONETIMEAUTH_BYTES:]
data[header_length: -ONETIMEAUTH_BYTES], _data = data[header_length: -ONETIMEAUTH_BYTES]
self._encryptor.decipher_iv + self._encryptor.key) is False: _key = self._encryptor.decipher_iv + self._encryptor.key
if onetimeauth_verify(_hash, _data, _key) is False:
logging.warn('UDP one time auth fail') logging.warn('UDP one time auth fail')
return return
self._one_time_authed = True self._one_time_authed = True
@ -274,7 +277,8 @@ class UDPRelay(object):
def _one_time_auth_chunk_data_gen(self, data): def _one_time_auth_chunk_data_gen(self, data):
data = chr(ord(data[0]) | ADDRTYPE_AUTH) + data[1:] data = chr(ord(data[0]) | ADDRTYPE_AUTH) + data[1:]
return data + onetimeauth_gen(data, self._encryptor.cipher_iv + self._encryptor.key) key = self._encryptor.cipher_iv + self._encryptor.key
return data + onetimeauth_gen(data, key)
def add_to_loop(self, loop): def add_to_loop(self, loop):
if self._eventloop: if self._eventloop: