use db key

This commit is contained in:
breakwa11 2015-06-10 12:35:49 +08:00
parent c27f6283a2
commit 7c2fe9fd56
4 changed files with 288 additions and 288 deletions

View file

@ -7,142 +7,145 @@ import time
import sys import sys
from server_pool import ServerPool from server_pool import ServerPool
import Config import Config
import traceback
class DbTransfer(object): class DbTransfer(object):
instance = None instance = None
def __init__(self): def __init__(self):
self.last_get_transfer = {} self.last_get_transfer = {}
@staticmethod @staticmethod
def get_instance(): def get_instance():
if DbTransfer.instance is None: if DbTransfer.instance is None:
DbTransfer.instance = DbTransfer() DbTransfer.instance = DbTransfer()
return DbTransfer.instance return DbTransfer.instance
def push_db_all_user(self): def push_db_all_user(self):
#更新用户流量到数据库 #更新用户流量到数据库
last_transfer = self.last_get_transfer last_transfer = self.last_get_transfer
curr_transfer = ServerPool.get_instance().get_servers_transfer() curr_transfer = ServerPool.get_instance().get_servers_transfer()
#上次和本次的增量 #上次和本次的增量
dt_transfer = {} dt_transfer = {}
for id in curr_transfer.keys(): for id in curr_transfer.keys():
if id in last_transfer: if id in last_transfer:
if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]: if last_transfer[id][0] == curr_transfer[id][0] and last_transfer[id][1] == curr_transfer[id][1]:
continue continue
elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: elif curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0:
continue continue
elif last_transfer[id][0] <= curr_transfer[id][0] and \ elif last_transfer[id][0] <= curr_transfer[id][0] and \
last_transfer[id][1] <= curr_transfer[id][1]: last_transfer[id][1] <= curr_transfer[id][1]:
dt_transfer[id] = [curr_transfer[id][0] - last_transfer[id][0], dt_transfer[id] = [curr_transfer[id][0] - last_transfer[id][0],
curr_transfer[id][1] - last_transfer[id][1]] curr_transfer[id][1] - last_transfer[id][1]]
else: else:
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
else: else:
if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0: if curr_transfer[id][0] == 0 and curr_transfer[id][1] == 0:
continue continue
dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]] dt_transfer[id] = [curr_transfer[id][0], curr_transfer[id][1]]
self.last_get_transfer = curr_transfer self.last_get_transfer = curr_transfer
query_head = 'UPDATE user' query_head = 'UPDATE user'
query_sub_when = '' query_sub_when = ''
query_sub_when2 = '' query_sub_when2 = ''
query_sub_in = None query_sub_in = None
last_time = time.time() last_time = time.time()
for id in dt_transfer.keys(): for id in dt_transfer.keys():
query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0]) query_sub_when += ' WHEN %s THEN u+%s' % (id, dt_transfer[id][0])
query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1]) query_sub_when2 += ' WHEN %s THEN d+%s' % (id, dt_transfer[id][1])
if query_sub_in is not None: if query_sub_in is not None:
query_sub_in += ',%s' % id query_sub_in += ',%s' % id
else: else:
query_sub_in = '%s' % id query_sub_in = '%s' % id
if query_sub_when == '': if query_sub_when == '':
return return
query_sql = query_head + ' SET u = CASE port' + query_sub_when + \ query_sql = query_head + ' SET u = CASE port' + query_sub_when + \
' END, d = CASE port' + query_sub_when2 + \ ' END, d = CASE port' + query_sub_when2 + \
' END, t = ' + str(int(last_time)) + \ ' END, t = ' + str(int(last_time)) + \
' WHERE port IN (%s)' % query_sub_in ' WHERE port IN (%s)' % query_sub_in
#print query_sql #print query_sql
conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER, conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER,
passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8') passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8')
cur = conn.cursor() cur = conn.cursor()
cur.execute(query_sql) cur.execute(query_sql)
cur.close() cur.close()
conn.commit() conn.commit()
conn.close() conn.close()
@staticmethod @staticmethod
def pull_db_all_user(): def pull_db_all_user():
#数据库所有用户信息 #数据库所有用户信息
conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER, keys = ['port', 'u', 'd', 'transfer_enable', 'passwd', 'switch', 'enable', 'plan' ]
passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8') conn = cymysql.connect(host=Config.MYSQL_HOST, port=Config.MYSQL_PORT, user=Config.MYSQL_USER,
cur = conn.cursor() passwd=Config.MYSQL_PASS, db=Config.MYSQL_DB, charset='utf8')
cur.execute("SELECT port, u, d, transfer_enable, passwd, switch, enable, plan FROM user") cur = conn.cursor()
rows = [] cur.execute("SELECT " + ','.join(keys) + " FROM user")
for r in cur.fetchall(): rows = []
rows.append(list(r)) for r in cur.fetchall():
cur.close() d = {}
conn.close() for column in xrange(len(keys)):
return rows d[keys[column]] = r[column]
rows.append(d)
cur.close()
conn.close()
return rows
@staticmethod @staticmethod
def del_server_out_of_bound_safe(last_rows, rows): def del_server_out_of_bound_safe(last_rows, rows):
#停止超流量的服务 #停止超流量的服务
#启动没超流量的服务 #启动没超流量的服务
#需要动态载入switchrule以便实时修改规则 #需要动态载入switchrule以便实时修改规则
cur_servers = {} cur_servers = {}
for row in rows: for row in rows:
try: try:
import switchrule import switchrule
allow = switchrule.isTurnOn(row[7], row[5]) and row[6] == 1 and row[1] + row[2] < row[3] allow = switchrule.isTurnOn(row) and row['enable'] == 1 and row['u'] + row['d'] < row['transfer_enable']
except Exception, e: except Exception, e:
allow = False allow = False
cur_servers[row[0]] = row[4] port = row['port']
passwd = row['passwd']
cur_servers[port] = passwd
if ServerPool.get_instance().server_is_run(row[0]) > 0: if ServerPool.get_instance().server_is_run(port) > 0:
if not allow: if not allow:
logging.info('db stop server at port [%s]' % (row[0])) logging.info('db stop server at port [%s]' % (port,))
ServerPool.get_instance().del_server(row[0]) ServerPool.get_instance().del_server(port)
elif (row[0] in ServerPool.get_instance().tcp_servers_pool and ServerPool.get_instance().tcp_servers_pool[row[0]]._config['password'] != row[4]) \ elif (port in ServerPool.get_instance().tcp_servers_pool and ServerPool.get_instance().tcp_servers_pool[port]._config['password'] != passwd) \
or (row[0] in ServerPool.get_instance().tcp_ipv6_servers_pool and ServerPool.get_instance().tcp_ipv6_servers_pool[row[0]]._config['password'] != row[4]): or (port in ServerPool.get_instance().tcp_ipv6_servers_pool and ServerPool.get_instance().tcp_ipv6_servers_pool[port]._config['password'] != passwd):
#password changed #password changed
logging.info('db stop server at port [%s] reason: password changed' % (row[0])) logging.info('db stop server at port [%s] reason: password changed' % (port,))
ServerPool.get_instance().del_server(row[0]) ServerPool.get_instance().del_server(port)
elif ServerPool.get_instance().server_run_status(row[0]) is False:
if allow:
logging.info('db start server at port [%s] pass [%s]' % (row[0], row[4]))
ServerPool.get_instance().new_server(row[0], row[4])
for row in last_rows: if allow and ServerPool.get_instance().server_is_run(port) == 0:
if row[0] in cur_servers: logging.info('db start server at port [%s] pass [%s]' % (port, passwd))
if row[4] == cur_servers[row[0]]: ServerPool.get_instance().new_server(port, passwd)
pass
else:
logging.info('db stop server at port [%s] reason: port not exist' % (row[0]))
ServerPool.get_instance().del_server(row[0])
@staticmethod for row in last_rows:
def thread_db(): if row['port'] in cur_servers:
import socket pass
import time else:
timeout = 60 logging.info('db stop server at port [%s] reason: port not exist' % (row['port']))
socket.setdefaulttimeout(timeout) ServerPool.get_instance().del_server(row['port'])
last_rows = []
while True:
#logging.warn('db loop')
try: @staticmethod
DbTransfer.get_instance().push_db_all_user() def thread_db():
rows = DbTransfer.get_instance().pull_db_all_user() import socket
DbTransfer.del_server_out_of_bound_safe(last_rows, rows) import time
last_rows = rows timeout = 60
except Exception as e: socket.setdefaulttimeout(timeout)
logging.warn('db thread except:%s' % e) last_rows = []
finally: while True:
time.sleep(15) try:
DbTransfer.get_instance().push_db_all_user()
rows = DbTransfer.get_instance().pull_db_all_user()
DbTransfer.del_server_out_of_bound_safe(last_rows, rows)
last_rows = rows
except Exception as e:
trace = traceback.format_exc()
logging.error(trace)
#logging.warn('db thread except:%s' % e)
finally:
time.sleep(15)
#SQLData.pull_db_all_user()
#print DbTransfer.get_instance().test()

View file

@ -1,6 +1,4 @@
#!/usr/bin/env python #!/usr/bin/python
# -*- coding: utf-8 -*-
import time import time
import sys import sys
import thread import thread
@ -11,12 +9,12 @@ import server_pool
import db_transfer import db_transfer
#def test(): #def test():
# thread.start_new_thread(DbTransfer.thread_db, ()) # thread.start_new_thread(DbTransfer.thread_db, ())
# Api.web_server() # Api.web_server()
if __name__ == '__main__': if __name__ == '__main__':
#server_pool.ServerPool.get_instance() #server_pool.ServerPool.get_instance()
#server_pool.ServerPool.get_instance().new_server(2333, '2333') #server_pool.ServerPool.get_instance().new_server(2333, '2333')
thread.start_new_thread(db_transfer.DbTransfer.thread_db, ()) thread.start_new_thread(db_transfer.DbTransfer.thread_db, ())
while True: while True:
time.sleep(99999) time.sleep(99999)

View file

@ -24,7 +24,7 @@
import os import os
import logging import logging
import time import time
from shadowsocks import utils from shadowsocks import shell
from shadowsocks import eventloop from shadowsocks import eventloop
from shadowsocks import tcprelay from shadowsocks import tcprelay
from shadowsocks import udprelay from shadowsocks import udprelay
@ -38,174 +38,172 @@ from socket import *
class ServerPool(object): class ServerPool(object):
instance = None instance = None
def __init__(self): def __init__(self):
utils.check_python() shell.check_python()
self.config = utils.get_config(False) self.config = shell.get_config(False)
utils.print_shadowsocks() shell.print_shadowsocks()
self.dns_resolver = asyncdns.DNSResolver() self.dns_resolver = asyncdns.DNSResolver()
self.mgr = asyncmgr.ServerMgr() self.mgr = asyncmgr.ServerMgr()
self.udp_on = True ### UDP switch ===================================== self.udp_on = True ### UDP switch =====================================
self.tcp_servers_pool = {} self.tcp_servers_pool = {}
self.tcp_ipv6_servers_pool = {} self.tcp_ipv6_servers_pool = {}
self.udp_servers_pool = {} self.udp_servers_pool = {}
self.udp_ipv6_servers_pool = {} self.udp_ipv6_servers_pool = {}
self.loop = eventloop.EventLoop() self.loop = eventloop.EventLoop()
thread.start_new_thread(ServerPool._loop, (self.loop, self.dns_resolver, self.mgr)) thread.start_new_thread(ServerPool._loop, (self.loop, self.dns_resolver, self.mgr))
@staticmethod @staticmethod
def get_instance(): def get_instance():
if ServerPool.instance is None: if ServerPool.instance is None:
ServerPool.instance = ServerPool() ServerPool.instance = ServerPool()
return ServerPool.instance return ServerPool.instance
@staticmethod @staticmethod
def _loop(loop, dns_resolver, mgr): def _loop(loop, dns_resolver, mgr):
try: try:
mgr.add_to_loop(loop) mgr.add_to_loop(loop)
dns_resolver.add_to_loop(loop) dns_resolver.add_to_loop(loop)
loop.run() loop.run()
except (KeyboardInterrupt, IOError, OSError) as e: except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e) logging.error(e)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
os.exit(0) os.exit(0)
def server_is_run(self, port): def server_is_run(self, port):
port = int(port) port = int(port)
ret = 0 ret = 0
if port in self.tcp_servers_pool: if port in self.tcp_servers_pool:
ret = 1 ret = 1
if port in self.tcp_ipv6_servers_pool: if port in self.tcp_ipv6_servers_pool:
ret |= 2 ret |= 2
return ret return ret
def server_run_status(self, port): def server_run_status(self, port):
if 'server' in self.config: if 'server' in self.config:
if port not in self.tcp_servers_pool: if port not in self.tcp_servers_pool:
return False return False
if 'server_ipv6' in self.config: if 'server_ipv6' in self.config:
if port not in self.tcp_ipv6_servers_pool: if port not in self.tcp_ipv6_servers_pool:
return False return False
return True return True
def new_server(self, port, password): def new_server(self, port, password):
ret = True ret = True
port = int(port) port = int(port)
if 'server_ipv6' in self.config: if 'server_ipv6' in self.config:
if port in self.tcp_ipv6_servers_pool: if port in self.tcp_ipv6_servers_pool:
logging.info("server already at %s:%d" % (self.config['server_ipv6'], port)) logging.info("server already at %s:%d" % (self.config['server_ipv6'], port))
return 'this port server is already running' return 'this port server is already running'
else: else:
a_config = self.config.copy() a_config = self.config.copy()
a_config['server'] = a_config['server_ipv6'] a_config['server'] = a_config['server_ipv6']
a_config['server_port'] = port a_config['server_port'] = port
a_config['password'] = password a_config['password'] = password
try: try:
logging.info("starting server at %s:%d" % (a_config['server'], port)) logging.info("starting server at %s:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server.add_to_loop(self.loop) tcp_server.add_to_loop(self.loop)
self.tcp_ipv6_servers_pool.update({port: tcp_server}) self.tcp_ipv6_servers_pool.update({port: tcp_server})
if self.udp_on: if self.udp_on:
udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False)
udp_server.add_to_loop(self.loop) udp_server.add_to_loop(self.loop)
self.udp_ipv6_servers_pool.update({port: udp_server}) self.udp_ipv6_servers_pool.update({port: udp_server})
except Exception, e: except Exception, e:
logging.warn("IPV6 exception") logging.warn("IPV6 %s " % (e,))
logging.warn(e)
if 'server' in self.config: if 'server' in self.config:
if port in self.tcp_servers_pool: if port in self.tcp_servers_pool:
logging.info("server already at %s:%d" % (self.config['server'], port)) logging.info("server already at %s:%d" % (self.config['server'], port))
return 'this port server is already running' return 'this port server is already running'
else: else:
a_config = self.config.copy() a_config = self.config.copy()
a_config['server_port'] = port a_config['server_port'] = port
a_config['password'] = password a_config['password'] = password
try: try:
logging.info("starting server at %s:%d" % (a_config['server'], port)) logging.info("starting server at %s:%d" % (a_config['server'], port))
tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False)
tcp_server.add_to_loop(self.loop) tcp_server.add_to_loop(self.loop)
self.tcp_servers_pool.update({port: tcp_server}) self.tcp_servers_pool.update({port: tcp_server})
if self.udp_on: if self.udp_on:
udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False) udp_server = udprelay.UDPRelay(a_config, self.dns_resolver, False)
udp_server.add_to_loop(self.loop) udp_server.add_to_loop(self.loop)
self.udp_servers_pool.update({port: udp_server}) self.udp_servers_pool.update({port: udp_server})
except Exception, e: except Exception, e:
logging.warn("IPV4 exception") logging.warn("IPV4 %s " % (e,))
logging.warn(e)
return True return True
def del_server(self, port): def del_server(self, port):
port = int(port) port = int(port)
logging.info("del server at %d" % port) logging.info("del server at %d" % port)
try: try:
udpsock = socket(AF_INET, SOCK_DGRAM) udpsock = socket(AF_INET, SOCK_DGRAM)
udpsock.sendto('%s:%s:0:0' % (Config.MANAGE_PASS, port), (Config.MANAGE_BIND_IP, Config.MANAGE_PORT)) udpsock.sendto('%s:%s:0:0' % (Config.MANAGE_PASS, port), (Config.MANAGE_BIND_IP, Config.MANAGE_PORT))
udpsock.close() udpsock.close()
except Exception, e: except Exception, e:
logging.warn(e) logging.warn(e)
return True return True
def cb_del_server(self, port): def cb_del_server(self, port):
port = int(port) port = int(port)
if port not in self.tcp_servers_pool: if port not in self.tcp_servers_pool:
logging.info("stopped server at %s:%d already stop" % (self.config['server'], port)) logging.info("stopped server at %s:%d already stop" % (self.config['server'], port))
else: else:
logging.info("stopped server at %s:%d" % (self.config['server'], port)) logging.info("stopped server at %s:%d" % (self.config['server'], port))
try: try:
self.tcp_servers_pool[port].destroy() self.tcp_servers_pool[port].destroy()
del self.tcp_servers_pool[port] del self.tcp_servers_pool[port]
except Exception, e: except Exception, e:
logging.warn(e) logging.warn(e)
if self.udp_on: if self.udp_on:
try: try:
self.udp_servers_pool[port].destroy() self.udp_servers_pool[port].destroy()
del self.udp_servers_pool[port] del self.udp_servers_pool[port]
except Exception, e: except Exception, e:
logging.warn(e) logging.warn(e)
if 'server_ipv6' in self.config: if 'server_ipv6' in self.config:
if port not in self.tcp_ipv6_servers_pool: if port not in self.tcp_ipv6_servers_pool:
logging.info("stopped server at %s:%d already stop" % (self.config['server_ipv6'], port)) logging.info("stopped server at %s:%d already stop" % (self.config['server_ipv6'], port))
else: else:
logging.info("stopped server at %s:%d" % (self.config['server_ipv6'], port)) logging.info("stopped server at %s:%d" % (self.config['server_ipv6'], port))
try: try:
self.tcp_ipv6_servers_pool[port].destroy() self.tcp_ipv6_servers_pool[port].destroy()
del self.tcp_ipv6_servers_pool[port] del self.tcp_ipv6_servers_pool[port]
except Exception, e: except Exception, e:
logging.warn(e) logging.warn(e)
if self.udp_on: if self.udp_on:
try: try:
self.udp_ipv6_servers_pool[port].destroy() self.udp_ipv6_servers_pool[port].destroy()
del self.udp_ipv6_servers_pool[port] del self.udp_ipv6_servers_pool[port]
except Exception, e: except Exception, e:
logging.warn(e) logging.warn(e)
return True return True
def get_server_transfer(self, port): def get_server_transfer(self, port):
port = int(port) port = int(port)
ret = [0, 0] ret = [0, 0]
if port in self.tcp_servers_pool: if port in self.tcp_servers_pool:
ret[0] = self.tcp_servers_pool[port].server_transfer_ul ret[0] = self.tcp_servers_pool[port].server_transfer_ul
ret[1] = self.tcp_servers_pool[port].server_transfer_dl ret[1] = self.tcp_servers_pool[port].server_transfer_dl
if port in self.tcp_ipv6_servers_pool: if port in self.tcp_ipv6_servers_pool:
ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul ret[0] += self.tcp_ipv6_servers_pool[port].server_transfer_ul
ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl ret[1] += self.tcp_ipv6_servers_pool[port].server_transfer_dl
return ret return ret
def get_servers_transfer(self): def get_servers_transfer(self):
servers = self.tcp_servers_pool.copy() servers = self.tcp_servers_pool.copy()
servers.update(self.tcp_ipv6_servers_pool) servers.update(self.tcp_ipv6_servers_pool)
ret = {} ret = {}
for port in servers.keys(): for port in servers.keys():
ret[port] = self.get_server_transfer(port) ret[port] = self.get_server_transfer(port)
return ret return ret

View file

@ -1,2 +1,3 @@
def isTurnOn(plan, switch): def isTurnOn(row):
return True return True