impl auth_chain client
This commit is contained in:
parent
774134ffad
commit
caadb606ba
3 changed files with 72 additions and 57 deletions
|
@ -336,34 +336,35 @@ class auth_chain_a(auth_base):
|
|||
return data
|
||||
|
||||
def pack_auth_data(self, auth_data, buf):
|
||||
if len(buf) == 0:
|
||||
return b''
|
||||
if len(buf) > 400:
|
||||
rnd_len = struct.unpack('<H', os.urandom(2))[0] % 512
|
||||
else:
|
||||
rnd_len = struct.unpack('<H', os.urandom(2))[0] % 1024
|
||||
data = auth_data
|
||||
data_len = 7 + 4 + 16 + 4 + len(buf) + rnd_len + 4
|
||||
data = data + struct.pack('<H', data_len) + struct.pack('<H', rnd_len)
|
||||
data_len = 12 + 4 + 16 + 4
|
||||
data = data + (struct.pack('<H', self.server_info.overhead) + struct.pack('<H', 0))
|
||||
mac_key = self.server_info.iv + self.server_info.key
|
||||
uid = os.urandom(4)
|
||||
|
||||
check_head = os.urandom(4)
|
||||
self.last_client_hash = hmac.new(mac_key, check_head, self.hashfunc).digest()
|
||||
check_head += self.last_client_hash[:8]
|
||||
|
||||
if b':' in to_bytes(self.server_info.protocol_param):
|
||||
try:
|
||||
items = to_bytes(self.server_info.protocol_param).split(b':')
|
||||
self.user_key = self.hashfunc(items[1]).digest()
|
||||
self.user_key = items[1]
|
||||
uid = struct.pack('<I', int(items[0]))
|
||||
except:
|
||||
pass
|
||||
uid = os.urandom(4)
|
||||
else:
|
||||
uid = os.urandom(4)
|
||||
if self.user_key is None:
|
||||
self.user_key = self.server_info.key
|
||||
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
|
||||
data = uid + encryptor.encrypt(data)[16:]
|
||||
data += hmac.new(mac_key, data, self.hashfunc).digest()[:4]
|
||||
check_head = os.urandom(1)
|
||||
check_head += hmac.new(mac_key, check_head, self.hashfunc).digest()[:6]
|
||||
data = check_head + data + os.urandom(rnd_len) + buf
|
||||
data += hmac.new(self.user_key, data, self.hashfunc).digest()[:4]
|
||||
return data
|
||||
|
||||
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key))+ to_bytes(base64.b64encode(self.last_client_hash)) + self.salt, 'aes-128-cbc', b'\x00' * 16)
|
||||
|
||||
uid = struct.unpack('<I', uid)[0] ^ struct.unpack('<I', self.last_client_hash[8:12])[0]
|
||||
uid = struct.pack('<I', uid)
|
||||
data = check_head + uid + encryptor.encrypt(data)[16:]
|
||||
self.last_server_hash = hmac.new(mac_key, data, self.hashfunc).digest()
|
||||
data += self.last_server_hash[:4]
|
||||
return data + self.pack_client_data(buf)
|
||||
|
||||
def auth_data(self):
|
||||
utc_time = int(time.time()) & 0xFFFFFFFF
|
||||
|
@ -400,30 +401,34 @@ class auth_chain_a(auth_base):
|
|||
out_buf = b''
|
||||
while len(self.recv_buf) > 4:
|
||||
mac_key = self.user_key + struct.pack('<I', self.recv_id)
|
||||
mac = hmac.new(mac_key, self.recv_buf[:2], self.hashfunc).digest()[:2]
|
||||
if mac != self.recv_buf[2:4]:
|
||||
raise Exception('client_post_decrypt data uncorrect mac')
|
||||
length = struct.unpack('<H', self.recv_buf[:2])[0]
|
||||
if length >= 8192 or length < 7:
|
||||
data_len = struct.unpack('<H', self.recv_buf[:2])[0] ^ struct.unpack('<H', self.last_server_hash[14:16])[0]
|
||||
rand_len = self.rnd_data_len(data_len, self.last_server_hash, self.random_server)
|
||||
length = data_len + rand_len
|
||||
if length >= 4096:
|
||||
self.raw_trans = True
|
||||
self.recv_buf = b''
|
||||
raise Exception('client_post_decrypt data error')
|
||||
if length > len(self.recv_buf):
|
||||
|
||||
if length + 4 > len(self.recv_buf):
|
||||
break
|
||||
|
||||
if hmac.new(mac_key, self.recv_buf[:length - 4], self.hashfunc).digest()[:4] != self.recv_buf[length - 4:length]:
|
||||
server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest()
|
||||
if server_hash[:2] != self.recv_buf[length + 2 : length + 4]:
|
||||
logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length])))
|
||||
self.raw_trans = True
|
||||
self.recv_buf = b''
|
||||
raise Exception('client_post_decrypt data uncorrect checksum')
|
||||
|
||||
pos = 2
|
||||
if data_len > 0 and rand_len > 0:
|
||||
pos = 2 + self.rnd_start_pos(rand_len, self.random_server)
|
||||
out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos])
|
||||
self.last_server_hash = server_hash
|
||||
if self.recv_id == 1:
|
||||
self.server_info.tcp_mss = out_buf[:2]
|
||||
out_buf = out_buf[2:]
|
||||
self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF
|
||||
pos = common.ord(self.recv_buf[4])
|
||||
if pos < 255:
|
||||
pos += 4
|
||||
else:
|
||||
pos = struct.unpack('<H', self.recv_buf[5:7])[0] + 4
|
||||
out_buf += self.recv_buf[pos:length - 4]
|
||||
self.recv_buf = self.recv_buf[length:]
|
||||
self.recv_buf = self.recv_buf[length + 4:]
|
||||
|
||||
return out_buf
|
||||
|
||||
|
|
|
@ -135,35 +135,25 @@ class TCPRelayHandler(object):
|
|||
self._remote_udp = False
|
||||
self._config = config
|
||||
self._dns_resolver = dns_resolver
|
||||
if not self._create_encryptor(config):
|
||||
return
|
||||
|
||||
self._client_address = local_sock.getpeername()[:2]
|
||||
self._accept_address = local_sock.getsockname()[:2]
|
||||
self._user = None
|
||||
self._user_id = server._listen_port
|
||||
self._tcp_mss = TCP_MSS
|
||||
self._update_tcp_mss(local_sock)
|
||||
|
||||
# TCP Relay works as either sslocal or ssserver
|
||||
# if is_local, this is sslocal
|
||||
self._is_local = is_local
|
||||
self._stage = STAGE_INIT
|
||||
try:
|
||||
self._encryptor = encrypt.Encryptor(config['password'],
|
||||
config['method'])
|
||||
except Exception:
|
||||
self._stage = STAGE_DESTROYED
|
||||
logging.error('create encryptor fail at port %d', server._listen_port)
|
||||
return
|
||||
self._encrypt_correct = True
|
||||
self._obfs = obfs.obfs(config['obfs'])
|
||||
self._protocol = obfs.obfs(config['protocol'])
|
||||
self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local)
|
||||
self._recv_buffer_size = BUF_SIZE - self._overhead
|
||||
|
||||
try:
|
||||
self._tcp_mss = local_sock.getsockopt(socket.SOL_TCP, socket.TCP_MAXSEG)
|
||||
logging.debug("TCP MSS = %d" % (self._tcp_mss,))
|
||||
except:
|
||||
pass
|
||||
|
||||
server_info = obfs.server_info(server.obfs_data)
|
||||
server_info.host = config['server']
|
||||
server_info.port = server._listen_port
|
||||
|
@ -180,6 +170,7 @@ class TCPRelayHandler(object):
|
|||
server_info.head_len = 30
|
||||
server_info.tcp_mss = self._tcp_mss
|
||||
server_info.buffer_size = self._recv_buffer_size
|
||||
server_info.overhead = self._overhead
|
||||
self._obfs.set_server_info(server_info)
|
||||
|
||||
server_info = obfs.server_info(server.protocol_data)
|
||||
|
@ -198,6 +189,7 @@ class TCPRelayHandler(object):
|
|||
server_info.head_len = 30
|
||||
server_info.tcp_mss = self._tcp_mss
|
||||
server_info.buffer_size = self._recv_buffer_size
|
||||
server_info.overhead = self._overhead
|
||||
self._protocol.set_server_info(server_info)
|
||||
|
||||
self._redir_list = config.get('redirect', ["*#0.0.0.0:0"])
|
||||
|
@ -213,27 +205,24 @@ class TCPRelayHandler(object):
|
|||
self._upstream_status = WAIT_STATUS_READING
|
||||
self._downstream_status = WAIT_STATUS_INIT
|
||||
self._remote_address = None
|
||||
if 'forbidden_ip' in config:
|
||||
self._forbidden_iplist = config['forbidden_ip']
|
||||
else:
|
||||
self._forbidden_iplist = None
|
||||
if 'forbidden_port' in config:
|
||||
self._forbidden_portset = config['forbidden_port']
|
||||
else:
|
||||
self._forbidden_portset = None
|
||||
|
||||
self._forbidden_iplist = config.get('forbidden_ip', None)
|
||||
self._forbidden_portset = config.get('forbidden_port', None)
|
||||
if is_local:
|
||||
self._chosen_server = self._get_a_server()
|
||||
|
||||
fd_to_handlers[local_sock.fileno()] = self
|
||||
local_sock.setblocking(False)
|
||||
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR,
|
||||
self._server)
|
||||
loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, self._server)
|
||||
|
||||
self.last_activity = 0
|
||||
self._update_activity()
|
||||
self._server.add_connection(1)
|
||||
self._server.stat_add(self._client_address[0], 1)
|
||||
self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0))
|
||||
self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0))
|
||||
self._recv_pack_id = 0
|
||||
|
||||
def __hash__(self):
|
||||
# default __hash__ is id / 16
|
||||
|
@ -254,6 +243,23 @@ class TCPRelayHandler(object):
|
|||
logging.debug('chosen server: %s:%d', server, server_port)
|
||||
return server, server_port
|
||||
|
||||
def _update_tcp_mss(self, local_sock):
|
||||
self._tcp_mss = TCP_MSS
|
||||
try:
|
||||
self._tcp_mss = local_sock.getsockopt(socket.SOL_TCP, socket.TCP_MAXSEG)
|
||||
logging.debug("TCP MSS = %d" % (self._tcp_mss,))
|
||||
except:
|
||||
pass
|
||||
|
||||
def _create_encryptor(self, config):
|
||||
try:
|
||||
self._encryptor = encrypt.Encryptor(config['password'],
|
||||
config['method'])
|
||||
return True
|
||||
except Exception:
|
||||
self._stage = STAGE_DESTROYED
|
||||
logging.error('create encryptor fail at port %d', self._server._listen_port)
|
||||
|
||||
def _update_user(self, user):
|
||||
self._user = user
|
||||
self._user_id = struct.unpack('<I', user)[0]
|
||||
|
@ -884,6 +890,7 @@ class TCPRelayHandler(object):
|
|||
else:
|
||||
recv_buffer_size = self._get_read_size(self._remote_sock, self._recv_buffer_size)
|
||||
data = self._remote_sock.recv(recv_buffer_size)
|
||||
self._recv_pack_id += 1
|
||||
except (OSError, IOError) as e:
|
||||
if eventloop.errno_from_exception(e) in \
|
||||
(errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK
|
||||
|
@ -912,6 +919,8 @@ class TCPRelayHandler(object):
|
|||
data = self._encryptor.decrypt(obfs_decode[0])
|
||||
try:
|
||||
data = self._protocol.client_post_decrypt(data)
|
||||
if self._recv_pack_id == 1:
|
||||
self._tcp_mss = self._protocol.get_server_info().tcp_mss
|
||||
except Exception as e:
|
||||
shell.print_exception(e)
|
||||
logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1]))
|
||||
|
|
|
@ -181,6 +181,7 @@ class UDPRelay(object):
|
|||
server_info.head_len = 30
|
||||
server_info.tcp_mss = 1452
|
||||
server_info.buffer_size = BUF_SIZE
|
||||
server_info.overhead = 0
|
||||
self._protocol.set_server_info(server_info)
|
||||
|
||||
self._sockets = set()
|
||||
|
|
Loading…
Add table
Reference in a new issue