fix transfer update
This commit is contained in:
parent
b1aade8640
commit
3ce6e6f714
3 changed files with 43 additions and 26 deletions
|
@ -18,12 +18,12 @@ class TransferBase(object):
|
|||
import threading
|
||||
self.event = threading.Event()
|
||||
self.key_list = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'enable']
|
||||
self.last_get_transfer = {}
|
||||
self.last_update_transfer = {}
|
||||
self.user_pass = {}
|
||||
self.port_uid_table = {}
|
||||
self.onlineuser_cache = lru_cache.LRUCache(timeout=60*30)
|
||||
self.pull_ok = False
|
||||
self.last_get_transfer = {} #上一次的实际流量
|
||||
self.last_update_transfer = {} #上一次更新到的流量(小于等于实际流量)
|
||||
self.force_update_transfer = set() #强制推入数据库的ID
|
||||
self.port_uid_table = {} #端口到uid的映射(仅v3以上有用)
|
||||
self.onlineuser_cache = lru_cache.LRUCache(timeout=60*30) #用户在线状态记录
|
||||
self.pull_ok = False #记录是否已经拉出过数据
|
||||
|
||||
def load_cfg(self):
|
||||
pass
|
||||
|
@ -36,7 +36,19 @@ class TransferBase(object):
|
|||
curr_transfer = ServerPool.get_instance().get_servers_transfer()
|
||||
#上次和本次的增量
|
||||
dt_transfer = {}
|
||||
for id in self.force_update_transfer: #此表中的用户统计上次未计入的流量
|
||||
if id in self.last_get_transfer and id in last_transfer:
|
||||
dt_transfer[id] = [self.last_get_transfer[id][0] - last_transfer[id][0], self.last_get_transfer[id][1] - last_transfer[id][1]]
|
||||
|
||||
for id in curr_transfer.keys():
|
||||
#有流量的,先记录在线状态
|
||||
if id in self.last_get_transfer:
|
||||
if curr_transfer[id][0] + curr_transfer[id][1] > self.last_get_transfer[id][0] + self.last_get_transfer[id][1]:
|
||||
self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1]
|
||||
else:
|
||||
self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1]
|
||||
|
||||
#算出与上次记录的流量差值,保存于dt_transfer表
|
||||
if id in last_transfer:
|
||||
if curr_transfer[id][0] + curr_transfer[id][1] - last_transfer[id][0] - last_transfer[id][1] <= 0:
|
||||
continue
|
||||
|
@ -50,17 +62,18 @@ class TransferBase(object):
|
|||
if curr_transfer[id][0] + curr_transfer[id][1] <= 0:
|
||||
continue
|
||||
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
|
||||
if id in self.last_get_transfer:
|
||||
if curr_transfer[id][0] + curr_transfer[id][1] > self.last_get_transfer[id][0] + self.last_get_transfer[id][1]:
|
||||
self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1]
|
||||
else:
|
||||
self.onlineuser_cache[id] = curr_transfer[id][0] + curr_transfer[id][1]
|
||||
|
||||
self.onlineuser_cache.sweep()
|
||||
|
||||
update_transfer = self.update_all_user(dt_transfer)
|
||||
for id in update_transfer.keys():
|
||||
last = self.last_update_transfer.get(id, [0,0])
|
||||
self.last_update_transfer[id] = [last[0] + update_transfer[id][0], last[1] + update_transfer[id][1]]
|
||||
update_transfer = self.update_all_user(dt_transfer) #返回有更新的表
|
||||
for id in update_transfer.keys(): #其增量加在此表
|
||||
if id in self.force_update_transfer: #但排除在force_update_transfer内的
|
||||
if id in self.last_update_transfer:
|
||||
del self.last_update_transfer[id]
|
||||
self.force_update_transfer.remove(id)
|
||||
else:
|
||||
last = self.last_update_transfer.get(id, [0,0])
|
||||
self.last_update_transfer[id] = [last[0] + update_transfer[id][0], last[1] + update_transfer[id][1]]
|
||||
self.last_get_transfer = curr_transfer
|
||||
|
||||
def del_server_out_of_bound_safe(self, last_rows, rows):
|
||||
|
@ -125,11 +138,7 @@ class TransferBase(object):
|
|||
new_servers[port] = (passwd, cfg)
|
||||
|
||||
elif allow and ServerPool.get_instance().server_run_status(port) is False:
|
||||
#new_servers[port] = passwd
|
||||
protocol = cfg.get('protocol', ServerPool.get_instance().config.get('protocol', 'origin'))
|
||||
obfs = cfg.get('obfs', ServerPool.get_instance().config.get('obfs', 'plain'))
|
||||
logging.info('db start server at port [%s] pass [%s] protocol [%s] obfs [%s]' % (port, passwd, protocol, obfs))
|
||||
ServerPool.get_instance().new_server(port, cfg)
|
||||
self.new_server(port, passwd, cfg)
|
||||
|
||||
for row in last_rows:
|
||||
if row['port'] in cur_servers:
|
||||
|
@ -145,10 +154,15 @@ class TransferBase(object):
|
|||
self.event.wait(eventloop.TIMEOUT_PRECISION + eventloop.TIMEOUT_PRECISION / 2)
|
||||
for port in new_servers.keys():
|
||||
passwd, cfg = new_servers[port]
|
||||
protocol = cfg.get('protocol', ServerPool.get_instance().config.get('protocol', 'origin'))
|
||||
obfs = cfg.get('obfs', ServerPool.get_instance().config.get('obfs', 'plain'))
|
||||
logging.info('db start server at port [%s] pass [%s] protocol [%s] obfs [%s]' % (port, passwd, protocol, obfs))
|
||||
ServerPool.get_instance().new_server(port, cfg)
|
||||
self.new_server(port, passwd, cfg)
|
||||
|
||||
def new_server(self, port, passwd, cfg):
|
||||
protocol = cfg.get('protocol', ServerPool.get_instance().config.get('protocol', 'origin'))
|
||||
method = cfg.get('method', ServerPool.get_instance().config.get('method', 'None'))
|
||||
obfs = cfg.get('obfs', ServerPool.get_instance().config.get('obfs', 'plain'))
|
||||
logging.info('db start server at port [%s] pass [%s] protocol [%s] method [%s] obfs [%s]' % (port, passwd, protocol, method, obfs))
|
||||
ServerPool.get_instance().new_server(port, cfg)
|
||||
self.force_update_transfer.add(port)
|
||||
|
||||
def cmp(self, val1, val2):
|
||||
if type(val1) is bytes:
|
||||
|
@ -206,6 +220,7 @@ class TransferBase(object):
|
|||
class DbTransfer(TransferBase):
|
||||
def __init__(self):
|
||||
super(DbTransfer, self).__init__()
|
||||
self.user_pass = {} #记录更新此用户流量时被跳过多少次
|
||||
self.cfg = {
|
||||
"host": "127.0.0.1",
|
||||
"port": 3306,
|
||||
|
@ -242,6 +257,7 @@ class DbTransfer(TransferBase):
|
|||
|
||||
for id in dt_transfer.keys():
|
||||
transfer = dt_transfer[id]
|
||||
#小于最低更新流量的先不更新
|
||||
update_trs = 1024 * max(2048 - self.user_pass.get(id, 0) * 64, 16)
|
||||
if transfer[0] + transfer[1] < update_trs:
|
||||
continue
|
||||
|
|
|
@ -53,7 +53,7 @@ EVENT_NAMES = {
|
|||
}
|
||||
|
||||
# we check timeouts every TIMEOUT_PRECISION seconds
|
||||
TIMEOUT_PRECISION = 10
|
||||
TIMEOUT_PRECISION = 5
|
||||
|
||||
|
||||
class KqueueLoop(object):
|
||||
|
|
|
@ -100,7 +100,8 @@ class http_simple(plain.plain):
|
|||
hosts = (self.server_info.obfs_param or self.server_info.host)
|
||||
pos = hosts.find("#")
|
||||
if pos >= 0:
|
||||
body = hosts[pos + 1:].replace("\\n", "\r\n")
|
||||
body = hosts[pos + 1:].replace("\n", "\r\n")
|
||||
body = body.replace("\\n", "\r\n")
|
||||
hosts = hosts[:pos]
|
||||
hosts = hosts.split(',')
|
||||
host = random.choice(hosts)
|
||||
|
|
Loading…
Add table
Reference in a new issue