shadowsocks/db_transfer.py
破娃酱 d968f01245 parse comment
Fix: variable 'rows' was not initialized before return
2017-05-04 10:56:14 +08:00

629 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import logging
import time
import sys
from server_pool import ServerPool
import traceback
from shadowsocks import common, shell, lru_cache, obfs
from configloader import load_config, get_config
import importloader
# Module-level state shared with TransferBase.thread_db / thread_db_stop.
# `db_instance` holds the active transfer object while thread_db runs and is
# reset to None when the loop exits. `switchrule` is only a placeholder: the
# methods in this module load switchrule.py into local variables instead.
switchrule, db_instance = None, None
class TransferBase(object):
    """Shared logic for syncing per-port traffic and server state with a user
    backend (a MySQL database or a JSON file, depending on the subclass).

    Subclasses implement ``pull_db_all_user()``, which returns a list of
    user-row dicts, and ``update_all_user(dt_transfer)``, which persists
    traffic deltas and returns the subset it actually wrote.
    """

    def __init__(self):
        import threading
        # Set by thread_db_stop() to wake and end the thread_db loop.
        self.event = threading.Event()
        # Base column list for user rows (subclasses may extend it).
        self.key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable']
        # port -> [upload, download] counters observed on the previous push.
        self.last_get_transfer = {}
        # port -> [upload, download] already written to the backend
        # (never more than the real counters).
        self.last_update_transfer = {}
        # Ports stopped since the last push whose not-yet-written traffic must
        # be flushed on the next push.
        self.force_update_transfer = set()
        # port -> user id; filled only when rows carry an 'id' column.
        self.port_uid_table = {}
        # Ports that recently showed new traffic (LRUCache timeout: 30 minutes).
        self.onlineuser_cache = lru_cache.LRUCache(timeout=60 * 30)
        # True once a pull returned at least one row; pushes wait for it.
        self.pull_ok = False
        # Multi-user ports from the last reconciliation (port -> password).
        self.mu_ports = {}

    def load_cfg(self):
        """Reload backend-specific configuration; a no-op in the base class."""
        pass

    def push_db_all_user(self):
        """Compute per-port traffic deltas and hand them to update_all_user().

        Does nothing until a pull has succeeded. Multi-user ports are skipped.
        Ports in force_update_transfer contribute only their not-yet-written
        traffic, and their cached counters are dropped afterwards.
        """
        if self.pull_ok is False:
            return
        last_transfer = self.last_update_transfer
        curr_transfer = ServerPool.get_instance().get_servers_transfer()
        dt_transfer = {}

        # Ports stopped during the last sync: flush what was counted but not written.
        for port in self.force_update_transfer:
            if port in self.last_get_transfer and port in last_transfer:
                got = self.last_get_transfer[port]
                done = last_transfer[port]
                dt_transfer[port] = [got[0] - done[0], got[1] - done[1]]

        for port in curr_transfer:
            if port in self.force_update_transfer or port in self.mu_ports:
                continue
            curr = curr_transfer[port]
            curr_total = curr[0] + curr[1]
            # Delta against what was already written; skip ports without new traffic.
            if port in last_transfer:
                last = last_transfer[port]
                if curr_total - last[0] - last[1] <= 0:
                    continue
                dt_transfer[port] = [curr[0] - last[0], curr[1] - last[1]]
            else:
                if curr_total <= 0:
                    continue
                dt_transfer[port] = [curr[0], curr[1]]
            # Counters grew since the previous push (or are new): record the port as online.
            prev = self.last_get_transfer
            if port not in prev or curr_total > prev[port][0] + prev[port][1]:
                self.onlineuser_cache[port] = curr_total

        self.onlineuser_cache.sweep()

        update_transfer = self.update_all_user(dt_transfer)
        # Advance written-so-far totals, except for force-updated ports whose
        # bookkeeping is discarded just below.
        for port in update_transfer:
            if port not in self.force_update_transfer:
                last = self.last_update_transfer.get(port, [0, 0])
                delta = update_transfer[port]
                self.last_update_transfer[port] = [last[0] + delta[0], last[1] + delta[1]]
        self.last_get_transfer = curr_transfer
        for port in self.force_update_transfer:
            self.last_update_transfer.pop(port, None)
            self.last_get_transfer.pop(port, None)
        self.force_update_transfer = set()

    def _config_changed(self, port, cfg, keys):
        """Return True if a relay running on *port* (IPv4 pool first, then
        IPv6) has a value for one of *keys* that differs from *cfg*."""
        pool = ServerPool.get_instance()
        for servers in (pool.tcp_servers_pool, pool.tcp_ipv6_servers_pool):
            if port in servers:
                relay_config = servers[port]._config
                for name in keys:
                    if name in cfg and not self.cmp(cfg[name], relay_config[name]):
                        return True
        return False

    def del_server_out_of_bound_safe(self, last_rows, rows):
        """Reconcile running servers with freshly pulled user rows.

        Stops servers of disabled or over-quota users and restarts servers whose
        config changed. Starts servers for newly allowed users, and stops
        servers whose port vanished since *last_rows*. Finally, it pushes the
        allowed-user table to every multi-user port.

        :param last_rows: rows handled by the previous call
        :param rows: current rows; each needs at least 'port' and 'passwd'
        """
        switchrule = None
        try:
            switchrule = importloader.load('switchrule')
        except Exception:
            # Without switchrule every row evaluates to "not allowed" below.
            logging.error('load switchrule.py fail')

        read_config_keys = ('method', 'obfs', 'obfs_param', 'protocol', 'protocol_param',
                            'forbidden_ip', 'forbidden_port', 'speed_limit_per_con',
                            'speed_limit_per_user')
        merge_config_keys = ('password',) + read_config_keys
        cur_servers = {}
        new_servers = {}
        allow_users = {}
        mu_servers = {}
        config = shell.get_config(False)
        pool = ServerPool.get_instance()

        for row in rows:
            try:
                allow = switchrule.isTurnOn(row) and row['enable'] == 1 and row['u'] + row['d'] < row['transfer_enable']
            except Exception:
                allow = False

            port = row['port']
            passwd = common.to_bytes(row['passwd'])
            if hasattr(passwd, 'encode'):
                passwd = passwd.encode('utf-8')
            cfg = {'password': passwd}
            if 'id' in row:
                self.port_uid_table[port] = row['id']
            for name in read_config_keys:
                if name in row and row[name]:
                    cfg[name] = row[name]
            for name in list(cfg):
                if hasattr(cfg[name], 'encode'):
                    try:
                        cfg[name] = cfg[name].encode('utf-8')
                    except Exception:
                        logging.warning('encode cfg key "%s" fail, val "%s"' % (name, cfg[name]))

            if port in cur_servers:
                logging.error('more than one user use the same port [%s]' % (port,))
                continue
            cur_servers[port] = passwd

            if allow:
                allow_users[port] = passwd
                # A multi-user protocol whose protocol_param contains '#' turns
                # this port into a multi-user server instead of a plain user.
                if 'protocol' in cfg and 'protocol_param' in cfg and common.to_str(cfg['protocol']) in obfs.mu_protocol():
                    if '#' in common.to_str(cfg['protocol_param']):
                        mu_servers[port] = passwd
                        del allow_users[port]

            cfgchange = self._config_changed(port, cfg, merge_config_keys)

            if port in mu_servers:
                if pool.server_is_run(port) > 0:
                    if cfgchange:
                        logging.info('db stop server at port [%s] reason: config changed: %s' % (port, cfg))
                        pool.cb_del_server(port)
                        self.force_update_transfer.add(port)
                        new_servers[port] = (passwd, cfg)
                else:
                    self.new_server(port, passwd, cfg)
            elif pool.server_is_run(port) > 0:
                if config['additional_ports_only'] or not allow:
                    logging.info('db stop server at port [%s]' % (port,))
                    pool.cb_del_server(port)
                    self.force_update_transfer.add(port)
                elif cfgchange:
                    logging.info('db stop server at port [%s] reason: config changed: %s' % (port, cfg))
                    pool.cb_del_server(port)
                    self.force_update_transfer.add(port)
                    new_servers[port] = (passwd, cfg)
            elif (not config['additional_ports_only'] and allow and 0 < port < 65536
                  and pool.server_run_status(port) is False):
                self.new_server(port, passwd, cfg)

        for row in last_rows:
            if row['port'] not in cur_servers:
                logging.info('db stop server at port [%s] reason: port not exist' % (row['port']))
                pool.cb_del_server(row['port'])
                self.clear_cache(row['port'])
                self.port_uid_table.pop(row['port'], None)

        if new_servers:
            from shadowsocks import eventloop
            # Presumably gives the event loop time to close the stopped servers
            # before they are restarted on the same ports.
            self.event.wait(eventloop.TIMEOUT_PRECISION + eventloop.TIMEOUT_PRECISION / 2)
            for port, (passwd, cfg) in new_servers.items():
                self.new_server(port, passwd, cfg)

        logging.debug('db allow users %s \nmu_servers %s' % (allow_users, mu_servers))
        for port in mu_servers:
            pool.update_mu_users(port, allow_users)
        self.mu_ports = mu_servers

    def clear_cache(self, port):
        """Forget all cached traffic state for *port*; missing entries are ignored."""
        # force_update_transfer is a set: it cannot be indexed with `del`.
        self.force_update_transfer.discard(port)
        self.last_get_transfer.pop(port, None)
        self.last_update_transfer.pop(port, None)

    def new_server(self, port, passwd, cfg):
        """Start a server on *port* with *cfg*. Missing protocol/method/obfs
        values are logged using ServerPool's configured defaults."""
        # NOTE(review): this logs the user's password at INFO level.
        defaults = ServerPool.get_instance().config
        protocol = cfg.get('protocol', defaults.get('protocol', 'origin'))
        method = cfg.get('method', defaults.get('method', 'None'))
        obfs_name = cfg.get('obfs', defaults.get('obfs', 'plain'))
        logging.info('db start server at port [%s] pass [%s] protocol [%s] method [%s] obfs [%s]' % (port, passwd, protocol, method, obfs_name))
        ServerPool.get_instance().new_server(port, cfg)

    def cmp(self, val1, val2):
        """Compare two config values, decoding bytes to str first."""
        if type(val1) is bytes:
            val1 = common.to_str(val1)
        if type(val2) is bytes:
            val2 = common.to_str(val2)
        return val1 == val2

    @staticmethod
    def del_servers():
        """Stop every running server in the IPv4 and IPv6 TCP pools."""
        pool = ServerPool.get_instance()
        for port in list(pool.tcp_servers_pool):
            if pool.server_is_run(port) > 0:
                pool.cb_del_server(port)
        for port in list(pool.tcp_ipv6_servers_pool):
            if pool.server_is_run(port) > 0:
                pool.cb_del_server(port)

    @staticmethod
    def thread_db(obj):
        """Run the sync loop with a new instance of transfer class *obj*.

        Each iteration does the following:
        - reloads configuration;
        - pushes traffic;
        - pulls users, plus the shell config's ``additional_ports``;
        - reconciles servers.

        Errors inside an iteration are logged and the loop continues. The loop
        ends when thread_db_stop() sets the event, when the ServerPool thread
        dies, or on KeyboardInterrupt. All servers are then stopped.
        """
        import socket
        global db_instance
        timeout = 60
        socket.setdefaulttimeout(timeout)
        last_rows = []
        db_instance = obj()
        ServerPool.get_instance()
        shell.log_shadowsocks_version()

        try:
            import resource
            logging.info('current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE))
        except Exception:
            # `resource` is Unix-only; the limit is logged for information only.
            pass

        try:
            while True:
                load_config()
                db_instance.load_cfg()
                try:
                    db_instance.push_db_all_user()
                    rows = db_instance.pull_db_all_user()
                    if rows:
                        db_instance.pull_ok = True
                        config = shell.get_config(False)
                        # Additional ports become always-enabled users with an
                        # effectively unlimited quota (1024 ** 7 bytes).
                        for port in config['additional_ports']:
                            val = config['additional_ports'][port]
                            val['port'] = int(port)
                            val['enable'] = 1
                            val['transfer_enable'] = 1024 ** 7
                            val['u'] = 0
                            val['d'] = 0
                            if "password" in val:
                                val["passwd"] = val["password"]
                            rows.append(val)
                    db_instance.del_server_out_of_bound_safe(last_rows, rows)
                    last_rows = rows
                except Exception:
                    logging.error(traceback.format_exc())
                if db_instance.event.wait(get_config().UPDATE_TIME) or not ServerPool.get_instance().thread.is_alive():
                    break
        except KeyboardInterrupt:
            pass
        db_instance.del_servers()
        ServerPool.get_instance().stop()
        db_instance = None

    @staticmethod
    def thread_db_stop():
        """Ask a running thread_db loop to exit; no-op if none is running."""
        global db_instance
        # thread_db resets db_instance to None on exit, so a late stop must not crash.
        if db_instance is not None:
            db_instance.event.set()
class DbTransfer(TransferBase):
    """MySQL backend (via cymysql) that reads users from the ``user`` table and
    adds traffic to its ``u``/``d`` columns, keyed by ``port``."""

    def __init__(self):
        super(DbTransfer, self).__init__()
        # port -> number of consecutive pushes whose delta was deferred as too small.
        self.user_pass = {}
        # Connection settings; overridden by the JSON file at MYSQL_CONFIG.
        self.cfg = {
            "host": "127.0.0.1",
            "port": 3306,
            "user": "ss",
            "password": "pass",
            "db": "shadowsocks",
            "node_id": 0,
            "transfer_mul": 1.0,  # multiplier applied to traffic before it is written
            "ssl_enable": 0,  # 1 -> connect with ssl_ca / ssl_cert / ssl_key
            "ssl_ca": "",
            "ssl_cert": "",
            "ssl_key": ""}
        self.load_cfg()

    def load_cfg(self):
        """Merge the JSON object stored at ``get_config().MYSQL_CONFIG`` into self.cfg."""
        import json
        config_path = get_config().MYSQL_CONFIG
        # The file is only read, so don't require write permission ('rb+' did).
        with open(config_path, 'rb') as f:
            cfg = json.loads(f.read().decode('utf8'))
        if cfg:
            self.cfg.update(cfg)

    def _connect(self):
        """Open a cymysql connection from self.cfg (with SSL when ssl_enable == 1)."""
        import cymysql
        connect_args = {
            'host': self.cfg["host"],
            'port': self.cfg["port"],
            'user': self.cfg["user"],
            'passwd': self.cfg["password"],
            'db': self.cfg["db"],
            'charset': 'utf8',
        }
        if self.cfg["ssl_enable"] == 1:
            connect_args['ssl'] = {'ca': self.cfg["ssl_ca"], 'cert': self.cfg["ssl_cert"], 'key': self.cfg["ssl_key"]}
        return cymysql.connect(**connect_args)

    def update_all_user(self, dt_transfer):
        """Add traffic deltas to the ``u``/``d`` columns in one bulk UPDATE.

        A delta below the threshold 1024 * (2048 - 64 * skips) bytes is deferred,
        unless the port is in force_update_transfer.

        :param dt_transfer: port -> [upload, download] bytes not yet written
        :return: port -> delta for the rows written, or {} if the query failed
        """
        update_transfer = {}
        when_u = []
        when_d = []
        ports = []
        last_time = time.time()
        mul = self.cfg["transfer_mul"]
        for port, transfer in dt_transfer.items():
            skipped = self.user_pass.get(port, 0)
            if transfer[0] + transfer[1] < 1024 * (2048 - skipped * 64) and port not in self.force_update_transfer:
                self.user_pass[port] = skipped + 1
                continue
            self.user_pass.pop(port, None)
            when_u.append(' WHEN %s THEN u+%s' % (port, int(transfer[0] * mul)))
            when_d.append(' WHEN %s THEN d+%s' % (port, int(transfer[1] * mul)))
            ports.append('%s' % port)
            update_transfer[port] = transfer
        if not ports:
            return update_transfer

        query_sql = ('UPDATE user SET u = CASE port' + ''.join(when_u) +
                     ' END, d = CASE port' + ''.join(when_d) +
                     ' END, t = ' + str(int(last_time)) +
                     ' WHERE port IN (%s)' % ','.join(ports))
        conn = self._connect()
        try:
            cur = conn.cursor()
            try:
                cur.execute(query_sql)
            except Exception as e:
                logging.error(e)
                update_transfer = {}
            finally:
                cur.close()
            conn.commit()
        except Exception as e:
            logging.error(e)
            update_transfer = {}
        finally:
            conn.close()
        return update_transfer

    def pull_db_all_user(self):
        """Return every user row (see pull_db_users); warns when there are none."""
        conn = self._connect()
        try:
            rows = self.pull_db_users(conn)
        finally:
            conn.close()
        if not rows:
            logging.warning('no user in db')
        return rows

    def pull_db_users(self, conn):
        """Select the configured columns of every row in ``user``.

        Columns come from ``switchrule.getKeys(self.key_list)`` when
        switchrule.py loads, otherwise from ``self.key_list``.

        :return: list of dicts keyed by column name
        """
        try:
            switchrule = importloader.load('switchrule')
            keys = switchrule.getKeys(self.key_list)
        except Exception:
            keys = self.key_list
        cur = conn.cursor()
        try:
            cur.execute("SELECT " + ','.join(keys) + " FROM user")
            rows = [dict(zip(keys, r)) for r in cur.fetchall()]
        finally:
            cur.close()
        return rows
class Dbv3Transfer(DbTransfer):
    """MySQL backend for panel schemas selected by ``API_INTERFACE``.

    Unless API_INTERFACE is 'legendsockssr' (``update_node_state`` False), it
    does the following:
    - selects the user ``id`` column;
    - reads the node's ``traffic_rate`` from ``ss_node``;
    - writes per-user traffic rows and node online/uptime/load rows.
    """

    def __init__(self):
        super(Dbv3Transfer, self).__init__()
        api_interface = get_config().API_INTERFACE
        self.update_node_state = api_interface != 'legendsockssr'
        if self.update_node_state:
            self.key_list += ['id']
        self.key_list += ['method']
        if self.update_node_state:
            self.ss_node_info_name = 'ss_node_info_log'
            if api_interface == 'sspanelv3ssr':
                self.key_list += ['obfs', 'protocol']
            if api_interface == 'glzjinmod':
                self.key_list += ['obfs', 'protocol']
                self.ss_node_info_name = 'ss_node_info'
        else:
            self.key_list += ['obfs', 'protocol']
        self.start_time = time.time()

    def update_all_user(self, dt_transfer):
        """Write traffic deltas and, when update_node_state, the log rows.

        A delta below 1024 * (2048 - 64 * skips) bytes is deferred. Force-updated
        ports are never deferred: push_db_all_user discards their bookkeeping
        afterwards, so a deferred delta would be lost.

        :param dt_transfer: port -> [upload, download] bytes not yet written
        :return: port -> delta for the rows written, or {} if the bulk UPDATE failed
        """
        import cymysql
        update_transfer = {}
        when_u = []
        when_d = []
        ports = []
        last_time = time.time()
        alive_user_count = len(self.onlineuser_cache)
        mul = self.cfg["transfer_mul"]

        connect_args = {
            'host': self.cfg["host"],
            'port': self.cfg["port"],
            'user': self.cfg["user"],
            'passwd': self.cfg["password"],
            'db': self.cfg["db"],
            'charset': 'utf8',
        }
        if self.cfg["ssl_enable"] == 1:
            connect_args['ssl'] = {'ca': self.cfg["ssl_ca"], 'cert': self.cfg["ssl_cert"], 'key': self.cfg["ssl_key"]}
        conn = cymysql.connect(**connect_args)
        try:
            conn.autocommit(True)
            for port, transfer in dt_transfer.items():
                skipped = self.user_pass.get(port, 0)
                if (transfer[0] + transfer[1] < 1024 * (2048 - skipped * 64)
                        and port not in self.force_update_transfer):
                    self.user_pass[port] = skipped + 1
                    continue
                self.user_pass.pop(port, None)
                when_u.append(' WHEN %s THEN u+%s' % (port, int(transfer[0] * mul)))
                when_d.append(' WHEN %s THEN d+%s' % (port, int(transfer[1] * mul)))
                ports.append('%s' % port)
                update_transfer[port] = transfer
                if self.update_node_state and port in self.port_uid_table:
                    self._log_user_traffic(conn, port, transfer)

            if ports:
                query_sql = ('UPDATE user SET u = CASE port' + ''.join(when_u) +
                             ' END, d = CASE port' + ''.join(when_d) +
                             ' END, t = ' + str(int(last_time)) +
                             ' WHERE port IN (%s)' % ','.join(ports))
                cur = conn.cursor()
                try:
                    cur.execute(query_sql)
                except Exception as e:
                    logging.error(e)
                    # Report nothing as written so the deltas are retried on the
                    # next push (the user_traffic_log rows above may then repeat).
                    update_transfer = {}
                finally:
                    cur.close()

            if self.update_node_state:
                self._log_node_state(conn, alive_user_count)
        finally:
            conn.close()
        return update_transfer

    def _log_user_traffic(self, conn, port, transfer):
        """Insert one ``user_traffic_log`` row for the user owning *port*."""
        cur = conn.cursor()
        try:
            cur.execute("INSERT INTO `user_traffic_log` (`id`, `user_id`, `u`, `d`, `node_id`, `rate`, `traffic`, `log_time`) VALUES (NULL, '" +
                        str(self.port_uid_table[port]) + "', '" + str(transfer[0]) + "', '" + str(transfer[1]) + "', '" +
                        str(self.cfg["node_id"]) + "', '" + str(self.cfg["transfer_mul"]) + "', '" +
                        self.traffic_format((transfer[0] + transfer[1]) * self.cfg["transfer_mul"]) + "', unix_timestamp()); ")
        except Exception:
            logging.warning('no `user_traffic_log` in db')
        finally:
            cur.close()

    def _log_node_state(self, conn, alive_user_count):
        """Insert the node's online-user count and its uptime/load row."""
        try:
            cur = conn.cursor()
            try:
                cur.execute("INSERT INTO `ss_node_online_log` (`id`, `node_id`, `online_user`, `log_time`) VALUES (NULL, '" +
                            str(self.cfg["node_id"]) + "', '" + str(alive_user_count) + "', unix_timestamp()); ")
            except Exception as e:
                logging.error(e)
            finally:
                cur.close()
            cur = conn.cursor()
            try:
                cur.execute("INSERT INTO `" + self.ss_node_info_name + "` (`id`, `node_id`, `uptime`, `load`, `log_time`) VALUES (NULL, '" +
                            str(self.cfg["node_id"]) + "', '" + str(self.uptime()) + "', '" +
                            str(self.load()) + "', unix_timestamp()); ")
            except Exception as e:
                logging.error(e)
            finally:
                cur.close()
        except Exception:
            # Previously the table name was inside the string literal and never interpolated.
            logging.warning('no `ss_node_online_log` or `' + self.ss_node_info_name + '` in db')

    def pull_db_users(self, conn):
        """Select user rows and, when update_node_state, refresh transfer_mul.

        When update_node_state is set, the node's ``traffic_rate`` is read from
        ``ss_node``. If that row is missing, [] is returned so no users are served.
        Errors in the user SELECT are logged, and the rows gathered so far are
        returned.

        :return: list of dicts keyed by column name
        """
        try:
            switchrule = importloader.load('switchrule')
            keys = switchrule.getKeys(self.key_list)
        except Exception:
            keys = self.key_list

        if self.update_node_state:
            node_info_keys = ['traffic_rate']
            cur = conn.cursor()
            try:
                cur.execute("SELECT " + ','.join(node_info_keys) + " FROM ss_node where `id`='" + str(self.cfg["node_id"]) + "'")
                nodeinfo = cur.fetchone()
            except Exception as e:
                logging.error(e)
                nodeinfo = None
            finally:
                cur.close()
            if nodeinfo is None:
                conn.commit()
                logging.warning('None result when select node info from ss_node in db, maybe you set the incorrect node id')
                return []
            node_info_dict = dict(zip(node_info_keys, nodeinfo))
            self.cfg['transfer_mul'] = float(node_info_dict['traffic_rate'])

        rows = []
        cur = conn.cursor()
        try:
            cur.execute("SELECT " + ','.join(keys) + " FROM user")
            for r in cur.fetchall():
                rows.append(dict(zip(keys, r)))
        except Exception as e:
            logging.error(e)
        finally:
            cur.close()
        return rows

    def load(self):
        """Return the first three fields of /proc/loadavg (the load averages) as
        'a b c\\n'. This is the same text the former `cat | awk` shell pipeline
        produced, now read without spawning a shell."""
        with open('/proc/loadavg') as f:
            fields = f.read().split()
        return ' '.join(fields[:3]) + '\n'

    def uptime(self):
        """Seconds elapsed since this instance was constructed."""
        return time.time() - self.start_time

    def traffic_format(self, traffic):
        """Format a byte count as 'NB' (< 8 KiB), 'N.NNKB' (< 2 MiB) or 'N.NNMB'."""
        if traffic < 1024 * 8:
            return str(int(traffic)) + "B"
        if traffic < 1024 * 1024 * 2:
            return str(round((traffic / 1024.0), 2)) + "KB"
        return str(round((traffic / 1048576.0), 2)) + "MB"
class MuJsonTransfer(TransferBase):
    """Transfer backend backed by the JSON user list at ``get_config().MUDB_FILE``."""

    def __init__(self):
        # No state beyond what TransferBase sets up.
        super(MuJsonTransfer, self).__init__()

    def update_all_user(self, dt_transfer):
        """Add each port's [upload, download] delta to its entry's "u"/"d".

        The JSON file is rewritten in place when it holds any entries.

        :param dt_transfer: port -> [upload, download]
        :return: *dt_transfer* itself, so every delta counts as persisted
        """
        import json
        mudb_path = get_config().MUDB_FILE
        with open(mudb_path, 'rb+') as mudb:
            users = json.loads(mudb.read().decode('utf8'))
        for user in users:
            if "port" not in user:
                continue
            port = user["port"]
            if port not in dt_transfer:
                continue
            delta = dt_transfer[port]
            user["u"] += delta[0]
            user["d"] += delta[1]
        if users:
            serialized = json.dumps(users, sort_keys=True, indent=4, separators=(',', ': '))
            # Overwrite in place; truncate() drops any leftover tail of the old content.
            with open(mudb_path, 'r+') as mudb:
                mudb.write(serialized)
                mudb.truncate()
        return dt_transfer

    def pull_db_all_user(self):
        """Load every entry from the JSON user file.

        'forbidden_ip' / 'forbidden_port' values are converted with
        common.IPNetwork / common.PortRange. A conversion error is logged and
        the raw value is kept.
        """
        import json
        mudb_path = get_config().MUDB_FILE
        with open(mudb_path, 'rb+') as mudb:
            users = json.loads(mudb.read().decode('utf8'))
        # The lambdas defer the `common.*` lookup so it happens inside the try.
        converters = (
            ('forbidden_ip', lambda value: common.IPNetwork(value)),
            ('forbidden_port', lambda value: common.PortRange(value)),
        )
        for user in users:
            for field, build in converters:
                try:
                    if field in user:
                        user[field] = build(user[field])
                except Exception as e:
                    logging.error(e)
        if not users:
            logging.warn('no user in json file')
        return users