multiuser in single port protocol
This commit is contained in:
parent
22ce739a45
commit
959aad3f41
5 changed files with 144 additions and 32 deletions
|
@ -1103,8 +1103,8 @@ class auth_aes128(auth_base):
|
|||
length = len(buf)
|
||||
data = buf[:-4]
|
||||
if struct.pack('<I', zlib.adler32(data) & 0xFFFFFFFF) != buf[length - 4:]:
|
||||
return b''
|
||||
return data
|
||||
return (b'', None)
|
||||
return (data, None)
|
||||
|
||||
class auth_aes128_sha1(auth_base):
|
||||
def __init__(self, method, hashfunc):
|
||||
|
@ -1280,9 +1280,15 @@ class auth_aes128_sha1(auth_base):
|
|||
return (b'', False)
|
||||
return self.not_match_return(self.recv_buf)
|
||||
|
||||
user_key = self.recv_buf[7:11]
|
||||
#if user_key in user_map: self.user_key[user_key] else: # TODO
|
||||
self.user_key = self.server_info.key
|
||||
uid = self.recv_buf[7:11]
|
||||
if uid in self.server_info.users:
|
||||
self.user_key = self.server_info.users[uid]
|
||||
self.server_info.update_user_func(uid)
|
||||
else:
|
||||
if not self.server_info.users:
|
||||
self.user_key = self.server_info.key
|
||||
else:
|
||||
self.user_key = self.server_info.recv_iv
|
||||
encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc')
|
||||
head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[11:27] + b'\x00') # need an extra byte or recv empty
|
||||
length = struct.unpack('<H', head[12:14])[0]
|
||||
|
@ -1377,8 +1383,15 @@ class auth_aes128_sha1(auth_base):
|
|||
|
||||
def server_udp_post_decrypt(self, buf):
|
||||
uid = buf[-8:-4]
|
||||
user_key = self.server_info.key
|
||||
if hmac.new(user_key, buf[:-4], self.hashfunc).digest()[:4] != buf[-4:]:
|
||||
return b''
|
||||
return buf[:-8]
|
||||
if uid in self.server_info.users:
|
||||
self.user_key = self.server_info.users[uid]
|
||||
else:
|
||||
uid = None
|
||||
if not self.server_info.users:
|
||||
self.user_key = self.server_info.key
|
||||
else:
|
||||
self.user_key = self.server_info.recv_iv
|
||||
if hmac.new(self.user_key, buf[:-4], self.hashfunc).digest()[:4] != buf[-4:]:
|
||||
return (b'', None)
|
||||
return (buf[:-8], uid)
|
||||
|
||||
|
|
|
@ -40,6 +40,9 @@ class plain(object):
|
|||
def init_data(self):
|
||||
return b''
|
||||
|
||||
def get_server_info(self):
|
||||
return self.server_info
|
||||
|
||||
def set_server_info(self, server_info):
|
||||
self.server_info = server_info
|
||||
|
||||
|
@ -79,7 +82,7 @@ class plain(object):
|
|||
return buf
|
||||
|
||||
def server_udp_post_decrypt(self, buf):
|
||||
return buf
|
||||
return (buf, None)
|
||||
|
||||
def dispose(self):
|
||||
pass
|
||||
|
|
|
@ -350,11 +350,11 @@ class verify_sha1(verify_base):
|
|||
def server_udp_post_decrypt(self, buf):
|
||||
if buf and ((ord(buf[0]) & 0x10) == 0x10):
|
||||
if len(buf) <= 11:
|
||||
return b''
|
||||
return (b'', None)
|
||||
sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, buf[:-10], hashlib.sha1).digest()[:10]
|
||||
if sha1data != buf[-10:]:
|
||||
return b''
|
||||
return to_bytes(chr(ord(buf[0]) & 0xEF)) + buf[1:-10]
|
||||
return (b'', None)
|
||||
return (to_bytes(chr(ord(buf[0]) & 0xEF)) + buf[1:-10], None)
|
||||
else:
|
||||
return buf
|
||||
return (buf, None)
|
||||
|
||||
|
|
|
@ -106,6 +106,7 @@ class TCPRelayHandler(object):
|
|||
self._dns_resolver = dns_resolver
|
||||
self._client_address = local_sock.getpeername()[:2]
|
||||
self._accept_address = local_sock.getsockname()[:2]
|
||||
self._user = None
|
||||
|
||||
# TCP Relay works as either sslocal or ssserver
|
||||
# if is_local, this is sslocal
|
||||
|
@ -123,6 +124,8 @@ class TCPRelayHandler(object):
|
|||
server_info = obfs.server_info(server.obfs_data)
|
||||
server_info.host = config['server']
|
||||
server_info.port = server._listen_port
|
||||
#server_info.users = server.server_users
|
||||
#server_info.update_user_func = self._update_user
|
||||
server_info.client = self._client_address[0]
|
||||
server_info.client_port = self._client_address[1]
|
||||
server_info.protocol_param = ''
|
||||
|
@ -139,6 +142,8 @@ class TCPRelayHandler(object):
|
|||
server_info = obfs.server_info(server.protocol_data)
|
||||
server_info.host = config['server']
|
||||
server_info.port = server._listen_port
|
||||
server_info.users = server.server_users
|
||||
server_info.update_user_func = self._update_user
|
||||
server_info.client = self._client_address[0]
|
||||
server_info.client_port = self._client_address[1]
|
||||
server_info.protocol_param = config['protocol_param']
|
||||
|
@ -203,6 +208,9 @@ class TCPRelayHandler(object):
|
|||
logging.debug('chosen server: %s:%d', server, server_port)
|
||||
return server, server_port
|
||||
|
||||
def _update_user(self, user):
|
||||
self._user = user
|
||||
|
||||
def _update_activity(self, data_len=0):
|
||||
# tell the TCP Relay we have activities recently
|
||||
# else it will think we are inactive and timed out
|
||||
|
@ -303,7 +311,7 @@ class TCPRelayHandler(object):
|
|||
try:
|
||||
if self._encrypt_correct:
|
||||
if sock == self._remote_sock:
|
||||
self._server.server_transfer_ul += len(data)
|
||||
self._server.add_transfer_u(self._user, len(data))
|
||||
self._update_activity(len(data))
|
||||
if data:
|
||||
l = len(data)
|
||||
|
@ -839,7 +847,7 @@ class TCPRelayHandler(object):
|
|||
data = self._encryptor.encrypt(data)
|
||||
data = self._obfs.server_encode(data)
|
||||
self._update_activity(len(data))
|
||||
self._server.server_transfer_dl += len(data)
|
||||
self._server.add_transfer_d(self._user, len(data))
|
||||
else:
|
||||
return
|
||||
try:
|
||||
|
@ -989,6 +997,9 @@ class TCPRelay(object):
|
|||
self._fd_to_handlers = {}
|
||||
self.server_transfer_ul = 0
|
||||
self.server_transfer_dl = 0
|
||||
self.server_users = {}
|
||||
self.server_user_transfer_ul = {}
|
||||
self.server_user_transfer_dl = {}
|
||||
self.server_connections = 0
|
||||
self.protocol_data = obfs.obfs(config['protocol']).init_data()
|
||||
self.obfs_data = obfs.obfs(config['obfs']).init_data()
|
||||
|
@ -1008,6 +1019,16 @@ class TCPRelay(object):
|
|||
listen_port = config['server_port']
|
||||
self._listen_port = listen_port
|
||||
|
||||
if config['protocol'] in ["auth_aes128_md5", "auth_aes128_sha1"]:
|
||||
user_list = config['protocol_param'].split(',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
items = user.split(':')
|
||||
if len(items) == 2:
|
||||
uid = struct.pack('<I', int(items[0]))
|
||||
passwd = items[1]
|
||||
self.add_user(uid, passwd)
|
||||
|
||||
addrs = socket.getaddrinfo(listen_addr, listen_port, 0,
|
||||
socket.SOCK_STREAM, socket.SOL_TCP)
|
||||
if len(addrs) == 0:
|
||||
|
@ -1047,6 +1068,29 @@ class TCPRelay(object):
|
|||
self.server_connections += val
|
||||
logging.debug('server port %5d connections = %d' % (self._listen_port, self.server_connections,))
|
||||
|
||||
def add_user(self, user, passwd): # user: binstr[4], passwd: str
|
||||
self.server_users[user] = common.to_bytes(passwd)
|
||||
|
||||
def del_user(self, user, passwd):
|
||||
if user in self.server_users:
|
||||
del self.server_users[user]
|
||||
|
||||
def add_transfer_u(self, user, transfer):
|
||||
if user is None:
|
||||
self.server_transfer_ul += transfer
|
||||
else:
|
||||
if user not in self.server_user_transfer_ul:
|
||||
self.server_user_transfer_ul[user] = 0
|
||||
self.server_user_transfer_ul[user] += transfer
|
||||
|
||||
def add_transfer_d(self, user, transfer):
|
||||
if user is None:
|
||||
self.server_transfer_dl += transfer
|
||||
else:
|
||||
if user not in self.server_user_transfer_dl:
|
||||
self.server_user_transfer_dl[user] = 0
|
||||
self.server_user_transfer_dl[user] += transfer
|
||||
|
||||
def update_stat(self, port, stat_dict, val):
|
||||
newval = stat_dict.get(0, 0) + val
|
||||
stat_dict[0] = newval
|
||||
|
|
|
@ -888,21 +888,35 @@ class UDPRelay(object):
|
|||
self._is_local = is_local
|
||||
self._udp_cache_size = config['udp_cache']
|
||||
self._cache = lru_cache.LRUCache(timeout=config['udp_timeout'],
|
||||
close_callback=self._close_client)
|
||||
close_callback=self._close_client_pair)
|
||||
self._cache_dns_client = lru_cache.LRUCache(timeout=10,
|
||||
close_callback=self._close_client)
|
||||
close_callback=self._close_client_pair)
|
||||
self._client_fd_to_server_addr = {}
|
||||
self._dns_cache = lru_cache.LRUCache(timeout=300)
|
||||
self._eventloop = None
|
||||
self._closed = False
|
||||
self.server_transfer_ul = 0
|
||||
self.server_transfer_dl = 0
|
||||
self.server_users = {}
|
||||
self.server_user_transfer_ul = {}
|
||||
self.server_user_transfer_dl = {}
|
||||
|
||||
if config['protocol'] in ["auth_aes128_md5", "auth_aes128_sha1"]:
|
||||
user_list = config['protocol_param'].split(',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
items = user.split(':')
|
||||
if len(items) == 2:
|
||||
uid = struct.pack('<I', int(items[0]))
|
||||
passwd = items[1]
|
||||
self.add_user(uid, passwd)
|
||||
|
||||
self.protocol_data = obfs.obfs(config['protocol']).init_data()
|
||||
self._protocol = obfs.obfs(config['protocol'])
|
||||
server_info = obfs.server_info(self.protocol_data)
|
||||
server_info.host = self._listen_addr
|
||||
server_info.port = self._listen_port
|
||||
server_info.users = self.server_users
|
||||
server_info.protocol_param = config['protocol_param']
|
||||
server_info.obfs_param = ''
|
||||
server_info.iv = b''
|
||||
|
@ -956,6 +970,33 @@ class UDPRelay(object):
|
|||
logging.debug('chosen server: %s:%d', server, server_port)
|
||||
return server, server_port
|
||||
|
||||
def add_user(self, user, passwd): # user: binstr[4], passwd: str
|
||||
self.server_users[user] = common.to_bytes(passwd)
|
||||
|
||||
def del_user(self, user, passwd):
|
||||
if user in self.server_users:
|
||||
del self.server_users[user]
|
||||
|
||||
def add_transfer_u(self, user, transfer):
|
||||
if user is None:
|
||||
self.server_transfer_ul += transfer
|
||||
else:
|
||||
if user not in self.server_user_transfer_ul:
|
||||
self.server_user_transfer_ul[user] = 0
|
||||
self.server_user_transfer_ul[user] += transfer
|
||||
|
||||
def add_transfer_d(self, user, transfer):
|
||||
if user is None:
|
||||
self.server_transfer_dl += transfer
|
||||
else:
|
||||
if user not in self.server_user_transfer_dl:
|
||||
self.server_user_transfer_dl[user] = 0
|
||||
self.server_user_transfer_dl[user] += transfer
|
||||
|
||||
def _close_client_pair(self, client_pair):
|
||||
client, uid = client_pair
|
||||
self._close_client(client)
|
||||
|
||||
def _close_client(self, client):
|
||||
if hasattr(client, 'close'):
|
||||
if not self._is_local:
|
||||
|
@ -1039,6 +1080,7 @@ class UDPRelay(object):
|
|||
logging.debug('UDP handle_server: data is empty')
|
||||
if self._stat_callback:
|
||||
self._stat_callback(self._listen_port, len(data))
|
||||
uid = None
|
||||
if self._is_local:
|
||||
frag = common.ord(data[2])
|
||||
if frag != 0:
|
||||
|
@ -1054,7 +1096,7 @@ class UDPRelay(object):
|
|||
logging.debug('UDP handle_server: data is empty after decrypt')
|
||||
return
|
||||
self._protocol.obfs.server_info.recv_iv = ref_iv[0]
|
||||
data = self._protocol.server_udp_post_decrypt(data)
|
||||
data, uid = self._protocol.server_udp_post_decrypt(data)
|
||||
|
||||
#logging.info("UDP data %s" % (binascii.hexlify(data),))
|
||||
if not self._is_local:
|
||||
|
@ -1097,10 +1139,10 @@ class UDPRelay(object):
|
|||
|
||||
af, socktype, proto, canonname, sa = addrs[0]
|
||||
key = client_key(r_addr, af)
|
||||
client = self._cache.get(key, None)
|
||||
if not client:
|
||||
client = self._cache_dns_client.get(key, None)
|
||||
if not client:
|
||||
client_pair = self._cache.get(key, None)
|
||||
if not client_pair:
|
||||
client_pair = self._cache_dns_client.get(key, None)
|
||||
if not client_pair:
|
||||
if self._forbidden_iplist:
|
||||
if common.to_str(sa[0]) in self._forbidden_iplist:
|
||||
logging.debug('IP %s is in forbidden list, drop' %
|
||||
|
@ -1114,6 +1156,7 @@ class UDPRelay(object):
|
|||
# drop
|
||||
return
|
||||
client = socket.socket(af, socktype, proto)
|
||||
client_uid = uid
|
||||
client.setblocking(False)
|
||||
self._socket_bind_addr(client, af)
|
||||
is_dns = False
|
||||
|
@ -1124,9 +1167,9 @@ class UDPRelay(object):
|
|||
#logging.info("unknown data %s" % (binascii.hexlify(data),))
|
||||
if sa[1] == 53 and is_dns: #DNS
|
||||
logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1]))
|
||||
self._cache_dns_client[key] = client
|
||||
self._cache_dns_client[key] = (client, uid)
|
||||
else:
|
||||
self._cache[key] = client
|
||||
self._cache[key] = (client, uid)
|
||||
self._client_fd_to_server_addr[client.fileno()] = (r_addr, af)
|
||||
|
||||
self._sockets.add(client.fileno())
|
||||
|
@ -1137,7 +1180,8 @@ class UDPRelay(object):
|
|||
common.connect_log('UDP data to %s:%d via port %d' %
|
||||
(common.to_str(server_addr), server_port,
|
||||
self._listen_port))
|
||||
|
||||
else:
|
||||
client, client_uid = client_pair
|
||||
self._cache.clear(self._udp_cache_size)
|
||||
self._cache_dns_client.clear(16)
|
||||
|
||||
|
@ -1156,7 +1200,7 @@ class UDPRelay(object):
|
|||
try:
|
||||
#logging.info('UDP handle_server sendto %s:%d %d bytes' % (common.to_str(server_addr), server_port, len(data)))
|
||||
client.sendto(data, (server_addr, server_port))
|
||||
self.server_transfer_ul += len(data)
|
||||
self.add_transfer_u(client_uid, len(data))
|
||||
except IOError as e:
|
||||
err = eventloop.errno_from_exception(e)
|
||||
if err in (errno.EINPROGRESS, errno.EAGAIN):
|
||||
|
@ -1266,14 +1310,22 @@ class UDPRelay(object):
|
|||
response = b'\x00\x00\x00' + data
|
||||
client_addr = self._client_fd_to_server_addr.get(sock.fileno())
|
||||
if client_addr:
|
||||
self.server_transfer_dl += len(response)
|
||||
self.write_to_server_socket(response, client_addr[0])
|
||||
key = client_key(client_addr[0], client_addr[1])
|
||||
client = self._cache_dns_client.get(key, None)
|
||||
if client:
|
||||
client_pair = self._cache.get(key, None)
|
||||
client_dns_pair = self._cache_dns_client.get(key, None)
|
||||
if client_pair:
|
||||
client, client_uid = client_pair
|
||||
self.add_transfer_d(client_uid, len(response))
|
||||
elif client_dns_pair:
|
||||
client, client_uid = client_dns_pair
|
||||
self.add_transfer_d(client_uid, len(response))
|
||||
else:
|
||||
self.server_transfer_dl += len(response)
|
||||
self.write_to_server_socket(response, client_addr[0])
|
||||
if client_dns_pair:
|
||||
logging.debug("remove dns client %s:%d" % (client_addr[0][0], client_addr[0][1]))
|
||||
del self._cache_dns_client[key]
|
||||
self._close_client(client)
|
||||
self._close_client(client_dns_pair[0])
|
||||
else:
|
||||
# this packet is from somewhere else we know
|
||||
# simply drop that packet
|
||||
|
|
Loading…
Add table
Reference in a new issue