bandwidth limit of the user
This commit is contained in:
parent
c4c1cd6dc1
commit
f296751def
3 changed files with 36 additions and 9 deletions
|
@ -14,6 +14,7 @@
|
|||
"obfs": "tls1.2_ticket_auth_compatible",
|
||||
"obfs_param": "",
|
||||
"speed_limit_per_con": 0,
|
||||
"speed_limit_per_user": 0,
|
||||
|
||||
"dns_ipv6": false,
|
||||
"connect_verbose_info": 0,
|
||||
|
|
|
@ -53,7 +53,7 @@ EVENT_NAMES = {
|
|||
}
|
||||
|
||||
# we check timeouts every TIMEOUT_PRECISION seconds
|
||||
TIMEOUT_PRECISION = 5
|
||||
TIMEOUT_PRECISION = 2
|
||||
|
||||
|
||||
class KqueueLoop(object):
|
||||
|
|
|
@ -133,6 +133,7 @@ class TCPRelayHandler(object):
|
|||
self._client_address = local_sock.getpeername()[:2]
|
||||
self._accept_address = local_sock.getsockname()[:2]
|
||||
self._user = None
|
||||
self._user_id = server._listen_port
|
||||
|
||||
# TCP Relay works as either sslocal or ssserver
|
||||
# if is_local, this is sslocal
|
||||
|
@ -238,6 +239,7 @@ class TCPRelayHandler(object):
|
|||
|
||||
def _update_user(self, user):
|
||||
self._user = user
|
||||
self._user_id = struct.unpack('<I', user)[0]
|
||||
|
||||
def _update_activity(self, data_len=0):
|
||||
# tell the TCP Relay we have activities recently
|
||||
|
@ -754,6 +756,7 @@ class TCPRelayHandler(object):
|
|||
return
|
||||
|
||||
self.speed_tester_u.add(len(data))
|
||||
self._server.speed_tester_u(self._user_id).add(len(data))
|
||||
ogn_data = data
|
||||
if not is_local:
|
||||
if self._encryptor is not None:
|
||||
|
@ -850,6 +853,7 @@ class TCPRelayHandler(object):
|
|||
return
|
||||
|
||||
self.speed_tester_d.add(len(data))
|
||||
self._server.speed_tester_d(self._user_id).add(len(data))
|
||||
if self._encryptor is not None:
|
||||
if self._is_local:
|
||||
try:
|
||||
|
@ -947,6 +951,7 @@ class TCPRelayHandler(object):
|
|||
return True
|
||||
if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
|
||||
if not self.speed_tester_d.isExceed():
|
||||
if not self._server.speed_tester_d(self._user_id).isExceed():
|
||||
handle = True
|
||||
self._on_remote_read(sock == self._remote_sock)
|
||||
if self._stage == STAGE_DESTROYED:
|
||||
|
@ -962,6 +967,7 @@ class TCPRelayHandler(object):
|
|||
return True
|
||||
if event & (eventloop.POLL_IN | eventloop.POLL_HUP):
|
||||
if not self.speed_tester_u.isExceed():
|
||||
if not self._server.speed_tester_u(self._user_id).isExceed():
|
||||
handle = True
|
||||
self._on_local_read()
|
||||
if self._stage == STAGE_DESTROYED:
|
||||
|
@ -1048,6 +1054,9 @@ class TCPRelay(object):
|
|||
self.server_users = {}
|
||||
self.server_user_transfer_ul = {}
|
||||
self.server_user_transfer_dl = {}
|
||||
self.mu = False
|
||||
self._speed_tester_u = {}
|
||||
self._speed_tester_d = {}
|
||||
self.update_users_protocol_param = None
|
||||
self.update_users_acl = None
|
||||
self.server_connections = 0
|
||||
|
@ -1122,6 +1131,7 @@ class TCPRelay(object):
|
|||
protocol_param = self._config['protocol_param']
|
||||
param = common.to_bytes(protocol_param).split(b'#')
|
||||
if len(param) == 2:
|
||||
self.mu = True
|
||||
user_list = param[1].split(b',')
|
||||
if user_list:
|
||||
for user in user_list:
|
||||
|
@ -1164,6 +1174,22 @@ class TCPRelay(object):
|
|||
self.server_user_transfer_dl[user] += transfer + self.server_transfer_dl
|
||||
self.server_transfer_dl = 0
|
||||
|
||||
def speed_tester_u(self, uid):
|
||||
if uid not in self._speed_tester_u:
|
||||
if self.mu: #TODO
|
||||
self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
|
||||
else:
|
||||
self._speed_tester_u[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
|
||||
return self._speed_tester_u[uid]
|
||||
|
||||
def speed_tester_d(self, uid):
|
||||
if uid not in self._speed_tester_d:
|
||||
if self.mu: #TODO
|
||||
self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
|
||||
else:
|
||||
self._speed_tester_d[uid] = SpeedTester(self._config.get("speed_limit_per_user", 0))
|
||||
return self._speed_tester_d[uid]
|
||||
|
||||
def update_stat(self, port, stat_dict, val):
|
||||
newval = stat_dict.get(0, 0) + val
|
||||
stat_dict[0] = newval
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue