multi-user for mudbjson mode
This commit is contained in:
parent
588d8ad8b3
commit
2de2a553d7
5 changed files with 188 additions and 53 deletions
|
@ -85,6 +85,8 @@ class TransferBase(object):
|
|||
logging.error('load switchrule.py fail')
|
||||
cur_servers = {}
|
||||
new_servers = {}
|
||||
allow_users = {}
|
||||
mu_servers = {}
|
||||
for row in rows:
|
||||
try:
|
||||
allow = switchrule.isTurnOn(row) and row['enable'] == 1 and row['u'] + row['d'] < row['transfer_enable']
|
||||
|
@ -113,33 +115,49 @@ class TransferBase(object):
|
|||
logging.error('more than one user use the same port [%s]' % (port,))
|
||||
continue
|
||||
|
||||
if ServerPool.get_instance().server_is_run(port) > 0:
|
||||
if not allow:
|
||||
logging.info('db stop server at port [%s]' % (port,))
|
||||
ServerPool.get_instance().cb_del_server(port)
|
||||
self.force_update_transfer.add(port)
|
||||
else:
|
||||
if allow:
|
||||
allow_users[port] = 1
|
||||
if 'protocol' in cfg and 'protocol_param' in cfg and common.to_str(cfg['protocol']) in ['auth_aes128_md5', 'auth_aes128_sha1']:
|
||||
if '#' in common.to_str(cfg['protocol_param']):
|
||||
mu_servers[port] = 1
|
||||
|
||||
cfgchange = False
|
||||
if port in ServerPool.get_instance().tcp_servers_pool:
|
||||
relay = ServerPool.get_instance().tcp_servers_pool[port]
|
||||
for name in merge_config_keys:
|
||||
if name in cfg and not self.cmp(cfg[name], relay._config[name]):
|
||||
cfgchange = True
|
||||
break;
|
||||
break
|
||||
if not cfgchange and port in ServerPool.get_instance().tcp_ipv6_servers_pool:
|
||||
relay = ServerPool.get_instance().tcp_ipv6_servers_pool[port]
|
||||
for name in merge_config_keys:
|
||||
if name in cfg and not self.cmp(cfg[name], relay._config[name]):
|
||||
cfgchange = True
|
||||
break;
|
||||
#config changed
|
||||
break
|
||||
|
||||
if port in mu_servers:
|
||||
if ServerPool.get_instance().server_is_run(port) > 0:
|
||||
if cfgchange:
|
||||
logging.info('db stop server at port [%s] reason: config changed: %s' % (port, cfg))
|
||||
ServerPool.get_instance().cb_del_server(port)
|
||||
self.force_update_transfer.add(port)
|
||||
new_servers[port] = (passwd, cfg)
|
||||
else:
|
||||
self.new_server(port, passwd, cfg)
|
||||
else:
|
||||
if ServerPool.get_instance().server_is_run(port) > 0:
|
||||
if not allow:
|
||||
logging.info('db stop server at port [%s]' % (port,))
|
||||
ServerPool.get_instance().cb_del_server(port)
|
||||
self.force_update_transfer.add(port)
|
||||
else:
|
||||
if cfgchange:
|
||||
logging.info('db stop server at port [%s] reason: config changed: %s' % (port, cfg))
|
||||
ServerPool.get_instance().cb_del_server(port)
|
||||
self.force_update_transfer.add(port)
|
||||
new_servers[port] = (passwd, cfg)
|
||||
|
||||
elif allow and ServerPool.get_instance().server_run_status(port) is False:
|
||||
elif allow and port > 0 and port < 65536 and ServerPool.get_instance().server_run_status(port) is False:
|
||||
self.new_server(port, passwd, cfg)
|
||||
|
||||
for row in last_rows:
|
||||
|
@ -159,6 +177,11 @@ class TransferBase(object):
|
|||
passwd, cfg = new_servers[port]
|
||||
self.new_server(port, passwd, cfg)
|
||||
|
||||
if isinstance(self, MuJsonTransfer): # works in MuJsonTransfer only
|
||||
logging.debug('db allow users %s \nmu_servers %s' % (allow_users, mu_servers))
|
||||
for port in mu_servers:
|
||||
ServerPool.get_instance().update_mu_server(port, None, allow_users)
|
||||
|
||||
def clear_cache(self, port):
|
||||
if port in self.force_update_transfer: del self.force_update_transfer[port]
|
||||
if port in self.last_get_transfer: del self.last_get_transfer[port]
|
||||
|
|
|
@ -73,8 +73,12 @@ class MuMgr(object):
|
|||
|
||||
def userinfo(self, user):
|
||||
ret = ""
|
||||
key_list = ['user', 'port', 'method', 'passwd', 'protocol', 'protocol_param', 'obfs', 'obfs_param', 'transfer_enable', 'u', 'd']
|
||||
for key in sorted(user):
|
||||
if key in ['enable']:
|
||||
if key not in key_list:
|
||||
key_list.append(key)
|
||||
for key in key_list:
|
||||
if key in ['enable'] or key not in user:
|
||||
continue
|
||||
ret += '\n'
|
||||
if key in ['transfer_enable', 'u', 'd']:
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
import os
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
from shadowsocks import shell, eventloop, tcprelay, udprelay, asyncdns, common
|
||||
import threading
|
||||
|
@ -213,23 +214,62 @@ class ServerPool(object):
|
|||
|
||||
return True
|
||||
|
||||
def update_mu_server(self, port, protocol_param, acl):
|
||||
port = int(port)
|
||||
if port in self.tcp_servers_pool:
|
||||
try:
|
||||
self.tcp_servers_pool[port].update_users(protocol_param, acl)
|
||||
except Exception as e:
|
||||
logging.warn(e)
|
||||
try:
|
||||
self.udp_servers_pool[port].update_users(protocol_param, acl)
|
||||
except Exception as e:
|
||||
logging.warn(e)
|
||||
if port in self.tcp_ipv6_servers_pool:
|
||||
try:
|
||||
self.tcp_ipv6_servers_pool[port].update_users(protocol_param, acl)
|
||||
except Exception as e:
|
||||
logging.warn(e)
|
||||
try:
|
||||
self.udp_ipv6_servers_pool[port].update_users(protocol_param, acl)
|
||||
except Exception as e:
|
||||
logging.warn(e)
|
||||
|
||||
def get_server_transfer(self, port):
|
||||
port = int(port)
|
||||
uid = struct.pack('<I', port)
|
||||
ret = [0, 0]
|
||||
if port in self.tcp_servers_pool:
|
||||
ret[0] = self.tcp_servers_pool[port].server_transfer_ul
|
||||
ret[1] = self.tcp_servers_pool[port].server_transfer_dl
|
||||
ret[0], ret[1] = self.tcp_servers_pool[port].get_ud()
|
||||
if port in self.udp_servers_pool:
|
||||
ret[0] += self.udp_servers_pool[port].server_transfer_ul
|
||||
ret[1] += self.udp_servers_pool[port].server_transfer_dl
|
||||
u, d = self.udp_servers_pool[port].get_ud()
|
||||
ret[0] += u
|
||||
ret[1] += d
|
||||
if port in self.tcp_ipv6_servers_pool:
|
||||
ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul
|
||||
ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl
|
||||
u, d = self.tcp_ipv6_servers_pool[port].get_ud()
|
||||
ret[0] += u
|
||||
ret[1] += d
|
||||
if port in self.udp_ipv6_servers_pool:
|
||||
ret[0] += self.udp_ipv6_servers_pool[port].server_transfer_ul
|
||||
ret[1] += self.udp_ipv6_servers_pool[port].server_transfer_dl
|
||||
u, d = self.udp_ipv6_servers_pool[port].get_ud()
|
||||
ret[0] += u
|
||||
ret[1] += d
|
||||
return ret
|
||||
|
||||
def get_server_mu_transfer(self, server):
|
||||
return server.get_users_ud()
|
||||
|
||||
def update_mu_transfer(self, user_dict, u, d):
|
||||
for uid in u:
|
||||
port = struct.unpack('<I', uid)[0]
|
||||
if port not in user_dict:
|
||||
user_dict[port] = [0, 0]
|
||||
user_dict[port][0] += u[uid]
|
||||
for uid in d:
|
||||
port = struct.unpack('<I', uid)[0]
|
||||
if port not in user_dict:
|
||||
user_dict[port] = [0, 0]
|
||||
user_dict[port][1] += d[uid]
|
||||
|
||||
def get_servers_transfer(self):
|
||||
servers = self.tcp_servers_pool.copy()
|
||||
servers.update(self.tcp_ipv6_servers_pool)
|
||||
|
@ -238,5 +278,17 @@ class ServerPool(object):
|
|||
ret = {}
|
||||
for port in servers.keys():
|
||||
ret[port] = self.get_server_transfer(port)
|
||||
for port in self.tcp_servers_pool:
|
||||
u, d = self.get_server_mu_transfer(self.tcp_servers_pool[port])
|
||||
self.update_mu_transfer(ret, u, d)
|
||||
for port in self.tcp_ipv6_servers_pool:
|
||||
u, d = self.get_server_mu_transfer(self.tcp_ipv6_servers_pool[port])
|
||||
self.update_mu_transfer(ret, u, d)
|
||||
for port in self.udp_servers_pool:
|
||||
u, d = self.get_server_mu_transfer(self.udp_servers_pool[port])
|
||||
self.update_mu_transfer(ret, u, d)
|
||||
for port in self.udp_ipv6_servers_pool:
|
||||
u, d = self.get_server_mu_transfer(self.udp_ipv6_servers_pool[port])
|
||||
self.update_mu_transfer(ret, u, d)
|
||||
return ret
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import binascii
|
|||
import traceback
|
||||
import random
|
||||
import platform
|
||||
import threading
|
||||
|
||||
from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache
|
||||
from shadowsocks.common import pre_parse_header, parse_header
|
||||
|
@ -900,6 +901,9 @@ class TCPRelayHandler(object):
|
|||
if self._stage == STAGE_DESTROYED:
|
||||
logging.debug('ignore handle_event: destroyed')
|
||||
return
|
||||
if self._user is not None and self._user not in self._server.server_users:
|
||||
self.destroy()
|
||||
return
|
||||
# order is important
|
||||
if sock == self._remote_sock or sock == self._remote_sock_v6:
|
||||
if event & eventloop.POLL_ERR:
|
||||
|
@ -1000,6 +1004,8 @@ class TCPRelay(object):
|
|||
self.server_users = {}
|
||||
self.server_user_transfer_ul = {}
|
||||
self.server_user_transfer_dl = {}
|
||||
self.update_users_protocol_param = None
|
||||
self.update_users_acl = None
|
||||
self.server_connections = 0
|
||||
self.protocol_data = obfs.obfs(config['protocol']).init_data()
|
||||
self.obfs_data = obfs.obfs(config['obfs']).init_data()
|
||||
|
@ -1020,16 +1026,7 @@ class TCPRelay(object):
|
|||
self._listen_port = listen_port
|
||||
|
||||
if common.to_bytes(config['protocol']) in [b"auth_aes128_md5", b"auth_aes128_sha1"]:
|
||||
param = common.to_bytes(config['protocol_param']).split(b'#')
|
||||
if len(param) == 2:
|
||||
user_list = param[1].split(b',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
items = user.split(b':')
|
||||
if len(items) == 2:
|
||||
uid = struct.pack('<I', int(items[0]))
|
||||
passwd = items[1]
|
||||
self.add_user(uid, passwd)
|
||||
self._update_users(None, None)
|
||||
|
||||
addrs = socket.getaddrinfo(listen_addr, listen_port, 0,
|
||||
socket.SOCK_STREAM, socket.SOL_TCP)
|
||||
|
@ -1070,10 +1067,38 @@ class TCPRelay(object):
|
|||
self.server_connections += val
|
||||
logging.debug('server port %5d connections = %d' % (self._listen_port, self.server_connections,))
|
||||
|
||||
def get_ud(self):
|
||||
return (self.server_transfer_ul, self.server_transfer_dl)
|
||||
|
||||
def get_users_ud(self):
|
||||
return (self.server_user_transfer_ul.copy(), self.server_user_transfer_dl.copy())
|
||||
|
||||
def _update_users(self, protocol_param, acl):
|
||||
if protocol_param is None:
|
||||
protocol_param = self._config['protocol_param']
|
||||
param = common.to_bytes(protocol_param).split(b'#')
|
||||
if len(param) == 2:
|
||||
user_list = param[1].split(b',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
items = user.split(b':')
|
||||
if len(items) == 2:
|
||||
user_int_id = int(items[0])
|
||||
uid = struct.pack('<I', user_int_id)
|
||||
if acl is not None and user_int_id not in acl:
|
||||
self.del_user(uid)
|
||||
else:
|
||||
passwd = items[1]
|
||||
self.add_user(uid, passwd)
|
||||
|
||||
def update_users(self, protocol_param, acl):
|
||||
self.update_users_protocol_param = protocol_param
|
||||
self.update_users_acl = acl
|
||||
|
||||
def add_user(self, user, passwd): # user: binstr[4], passwd: str
|
||||
self.server_users[user] = common.to_bytes(passwd)
|
||||
|
||||
def del_user(self, user, passwd):
|
||||
def del_user(self, user):
|
||||
if user in self.server_users:
|
||||
del self.server_users[user]
|
||||
|
||||
|
@ -1189,6 +1214,10 @@ class TCPRelay(object):
|
|||
logging.info('closed TCP port %d', self._listen_port)
|
||||
for handler in list(self._fd_to_handlers.values()):
|
||||
handler.destroy()
|
||||
elif self.update_users_protocol_param is not None or self.update_users_acl is not None:
|
||||
self._update_users(self.update_users_protocol_param, self.update_users_acl)
|
||||
self.update_users_protocol_param = None
|
||||
self.update_users_acl = None
|
||||
self._sweep_timeout()
|
||||
|
||||
def close(self, next_tick=False):
|
||||
|
|
|
@ -70,6 +70,7 @@ import errno
|
|||
import random
|
||||
import binascii
|
||||
import traceback
|
||||
import threading
|
||||
|
||||
from shadowsocks import encrypt, obfs, eventloop, lru_cache, common, shell
|
||||
from shadowsocks.common import pre_parse_header, parse_header, pack_addr
|
||||
|
@ -900,18 +901,11 @@ class UDPRelay(object):
|
|||
self.server_users = {}
|
||||
self.server_user_transfer_ul = {}
|
||||
self.server_user_transfer_dl = {}
|
||||
self.update_users_protocol_param = None
|
||||
self.update_users_acl = None
|
||||
|
||||
if common.to_bytes(config['protocol']) in [b"auth_aes128_md5", b"auth_aes128_sha1"]:
|
||||
param = common.to_bytes(config['protocol_param']).split(b'#')
|
||||
if len(param) == 2:
|
||||
user_list = param[1].split(b',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
items = user.split(b':')
|
||||
if len(items) == 2:
|
||||
uid = struct.pack('<I', int(items[0]))
|
||||
passwd = items[1]
|
||||
self.add_user(uid, passwd)
|
||||
self._update_users(None, None)
|
||||
|
||||
self.protocol_data = obfs.obfs(config['protocol']).init_data()
|
||||
self._protocol = obfs.obfs(config['protocol'])
|
||||
|
@ -972,10 +966,39 @@ class UDPRelay(object):
|
|||
logging.debug('chosen server: %s:%d', server, server_port)
|
||||
return server, server_port
|
||||
|
||||
def get_ud(self):
|
||||
return (self.server_transfer_ul, self.server_transfer_dl)
|
||||
|
||||
def get_users_ud(self):
|
||||
ret = (self.server_user_transfer_ul.copy(), self.server_user_transfer_dl.copy())
|
||||
return ret
|
||||
|
||||
def _update_users(self, protocol_param, acl):
|
||||
if protocol_param is None:
|
||||
protocol_param = self._config['protocol_param']
|
||||
param = common.to_bytes(protocol_param).split(b'#')
|
||||
if len(param) == 2:
|
||||
user_list = param[1].split(b',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
items = user.split(b':')
|
||||
if len(items) == 2:
|
||||
user_int_id = int(items[0])
|
||||
uid = struct.pack('<I', user_int_id)
|
||||
if acl is not None and user_int_id not in acl:
|
||||
self.del_user(uid)
|
||||
else:
|
||||
passwd = items[1]
|
||||
self.add_user(uid, passwd)
|
||||
|
||||
def update_users(self, protocol_param, acl):
|
||||
self.update_users_protocol_param = protocol_param
|
||||
self.update_users_acl = acl
|
||||
|
||||
def add_user(self, user, passwd): # user: binstr[4], passwd: str
|
||||
self.server_users[user] = common.to_bytes(passwd)
|
||||
|
||||
def del_user(self, user, passwd):
|
||||
def del_user(self, user):
|
||||
if user in self.server_users:
|
||||
del self.server_users[user]
|
||||
|
||||
|
@ -1434,6 +1457,10 @@ class UDPRelay(object):
|
|||
self._dns_cache.sweep()
|
||||
if before_sweep_size != len(self._sockets):
|
||||
logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets)))
|
||||
if self.update_users_protocol_param is not None or self.update_users_acl is not None:
|
||||
self._update_users(self.update_users_protocol_param, self.update_users_acl)
|
||||
self.update_users_protocol_param = None
|
||||
self.update_users_acl = None
|
||||
self._sweep_timeout()
|
||||
|
||||
def close(self, next_tick=False):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue