refine loop
This commit is contained in:
parent
5e19fdc66b
commit
0c8a8ef23f
5 changed files with 120 additions and 136 deletions
|
@ -28,6 +28,8 @@
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import select
|
import select
|
||||||
|
import errno
|
||||||
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,25 +156,24 @@ class EventLoop(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if hasattr(select, 'epoll'):
|
if hasattr(select, 'epoll'):
|
||||||
self._impl = EpollLoop()
|
self._impl = EpollLoop()
|
||||||
self._model = 'epoll'
|
model = 'epoll'
|
||||||
elif hasattr(select, 'kqueue'):
|
elif hasattr(select, 'kqueue'):
|
||||||
self._impl = KqueueLoop()
|
self._impl = KqueueLoop()
|
||||||
self._model = 'kqueue'
|
model = 'kqueue'
|
||||||
elif hasattr(select, 'select'):
|
elif hasattr(select, 'select'):
|
||||||
self._impl = SelectLoop()
|
self._impl = SelectLoop()
|
||||||
self._model = 'select'
|
model = 'select'
|
||||||
else:
|
else:
|
||||||
raise Exception('can not find any available functions in select '
|
raise Exception('can not find any available functions in select '
|
||||||
'package')
|
'package')
|
||||||
self._fd_to_f = {}
|
self._fd_to_f = {}
|
||||||
|
self._handlers = []
|
||||||
@property
|
self.stopping = False
|
||||||
def model(self):
|
logging.debug('using event model: %s', model)
|
||||||
return self._model
|
|
||||||
|
|
||||||
def poll(self, timeout=None):
|
def poll(self, timeout=None):
|
||||||
events = self._impl.poll(timeout)
|
events = self._impl.poll(timeout)
|
||||||
return ((self._fd_to_f[fd], event) for fd, event in events)
|
return [(self._fd_to_f[fd], fd, event) for fd, event in events]
|
||||||
|
|
||||||
def add(self, f, mode):
|
def add(self, f, mode):
|
||||||
fd = f.fileno()
|
fd = f.fileno()
|
||||||
|
@ -188,6 +189,26 @@ class EventLoop(object):
|
||||||
fd = f.fileno()
|
fd = f.fileno()
|
||||||
self._impl.modify_fd(fd, mode)
|
self._impl.modify_fd(fd, mode)
|
||||||
|
|
||||||
|
def add_handler(self, handler):
|
||||||
|
self._handlers.append(handler)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while not self.stopping:
|
||||||
|
events = None
|
||||||
|
try:
|
||||||
|
events = self.poll(1)
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
if errno_from_exception(e) == errno.EPIPE:
|
||||||
|
# Happens when the client closes the connection
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logging.error(e)
|
||||||
|
continue
|
||||||
|
for handler in self._handlers:
|
||||||
|
# no exceptions should be raised by users
|
||||||
|
# TODO when there are a lot of handlers
|
||||||
|
handler(events)
|
||||||
|
|
||||||
|
|
||||||
# from tornado
|
# from tornado
|
||||||
def errno_from_exception(e):
|
def errno_from_exception(e):
|
||||||
|
|
|
@ -24,8 +24,9 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import encrypt
|
|
||||||
import utils
|
import utils
|
||||||
|
import encrypt
|
||||||
|
import eventloop
|
||||||
import tcprelay
|
import tcprelay
|
||||||
import udprelay
|
import udprelay
|
||||||
|
|
||||||
|
@ -49,11 +50,12 @@ def main():
|
||||||
logging.info("starting local at %s:%d" %
|
logging.info("starting local at %s:%d" %
|
||||||
(config['local_address'], config['local_port']))
|
(config['local_address'], config['local_port']))
|
||||||
|
|
||||||
# TODO combine the two threads into one loop on a single thread
|
tcp_server = tcprelay.TCPRelay(config, True)
|
||||||
udprelay.UDPRelay(config, True).start()
|
udp_server = udprelay.UDPRelay(config, True)
|
||||||
tcprelay.TCPRelay(config, True).start()
|
loop = eventloop.EventLoop()
|
||||||
while sys.stdin.read():
|
tcp_server.add_to_loop(loop)
|
||||||
pass
|
udp_server.add_to_loop(loop)
|
||||||
|
loop.run()
|
||||||
except (KeyboardInterrupt, IOError, OSError) as e:
|
except (KeyboardInterrupt, IOError, OSError) as e:
|
||||||
logging.error(e)
|
logging.error(e)
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
|
@ -22,11 +22,11 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import socket
|
|
||||||
import logging
|
|
||||||
import encrypt
|
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
import utils
|
import utils
|
||||||
|
import encrypt
|
||||||
|
import eventloop
|
||||||
import tcprelay
|
import tcprelay
|
||||||
import udprelay
|
import udprelay
|
||||||
|
|
||||||
|
@ -56,19 +56,17 @@ def main():
|
||||||
a_config['password'] = password
|
a_config['password'] = password
|
||||||
logging.info("starting server at %s:%d" %
|
logging.info("starting server at %s:%d" %
|
||||||
(a_config['server'], int(port)))
|
(a_config['server'], int(port)))
|
||||||
tcp_server = tcprelay.TCPRelay(a_config, False)
|
tcp_servers.append(tcprelay.TCPRelay(a_config, False))
|
||||||
tcp_servers.append(tcp_server)
|
udp_servers.append(udprelay.UDPRelay(a_config, False))
|
||||||
udp_server = udprelay.UDPRelay(a_config, False)
|
|
||||||
udp_servers.append(udp_server)
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
try:
|
try:
|
||||||
|
loop = eventloop.EventLoop()
|
||||||
for tcp_server in tcp_servers:
|
for tcp_server in tcp_servers:
|
||||||
tcp_server.start()
|
tcp_server.add_to_loop(loop)
|
||||||
for udp_server in udp_servers:
|
for udp_server in udp_servers:
|
||||||
udp_server.start()
|
udp_server.add_to_loop(loop)
|
||||||
while sys.stdin.read():
|
loop.run()
|
||||||
pass
|
|
||||||
except (KeyboardInterrupt, IOError, OSError) as e:
|
except (KeyboardInterrupt, IOError, OSError) as e:
|
||||||
logging.error(e)
|
logging.error(e)
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
@ -96,10 +94,10 @@ def main():
|
||||||
signal.signal(signal.SIGTERM, handler)
|
signal.signal(signal.SIGTERM, handler)
|
||||||
|
|
||||||
# master
|
# master
|
||||||
for tcp_server in tcp_servers:
|
for a_tcp_server in tcp_servers:
|
||||||
tcp_server.close()
|
a_tcp_server.close()
|
||||||
for udp_server in udp_servers:
|
for a_udp_server in udp_servers:
|
||||||
udp_server.close()
|
a_udp_server.close()
|
||||||
|
|
||||||
for child in children:
|
for child in children:
|
||||||
os.waitpid(child, 0)
|
os.waitpid(child, 0)
|
||||||
|
@ -111,7 +109,4 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
try:
|
main()
|
||||||
main()
|
|
||||||
except socket.error, e:
|
|
||||||
logging.error(e)
|
|
||||||
|
|
|
@ -26,7 +26,6 @@ import socket
|
||||||
import logging
|
import logging
|
||||||
import encrypt
|
import encrypt
|
||||||
import errno
|
import errno
|
||||||
import threading
|
|
||||||
import eventloop
|
import eventloop
|
||||||
from common import parse_header
|
from common import parse_header
|
||||||
|
|
||||||
|
@ -303,6 +302,7 @@ class TCPRelayHandler(object):
|
||||||
logging.warn('unknown socket')
|
logging.warn('unknown socket')
|
||||||
|
|
||||||
def destroy(self):
|
def destroy(self):
|
||||||
|
logging.debug('destroy')
|
||||||
if self._remote_sock:
|
if self._remote_sock:
|
||||||
self._loop.remove(self._remote_sock)
|
self._loop.remove(self._remote_sock)
|
||||||
del self._fd_to_handlers[self._remote_sock.fileno()]
|
del self._fd_to_handlers[self._remote_sock.fileno()]
|
||||||
|
@ -320,8 +320,9 @@ class TCPRelay(object):
|
||||||
self._config = config
|
self._config = config
|
||||||
self._is_local = is_local
|
self._is_local = is_local
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._thread = None
|
self._eventloop = None
|
||||||
self._fd_to_handlers = {}
|
self._fd_to_handlers = {}
|
||||||
|
self._last_time = time.time()
|
||||||
|
|
||||||
if is_local:
|
if is_local:
|
||||||
listen_addr = config['local_address']
|
listen_addr = config['local_address']
|
||||||
|
@ -343,70 +344,48 @@ class TCPRelay(object):
|
||||||
server_socket.listen(1024)
|
server_socket.listen(1024)
|
||||||
self._server_socket = server_socket
|
self._server_socket = server_socket
|
||||||
|
|
||||||
def _run(self):
|
def add_to_loop(self, loop):
|
||||||
server_socket = self._server_socket
|
|
||||||
self._eventloop = eventloop.EventLoop()
|
|
||||||
logging.debug('using event model: %s', self._eventloop.model)
|
|
||||||
self._eventloop.add(server_socket,
|
|
||||||
eventloop.POLL_IN | eventloop.POLL_ERR)
|
|
||||||
last_time = time.time()
|
|
||||||
while not self._closed:
|
|
||||||
try:
|
|
||||||
events = self._eventloop.poll(1)
|
|
||||||
except (OSError, IOError) as e:
|
|
||||||
if eventloop.errno_from_exception(e) == errno.EPIPE:
|
|
||||||
# Happens when the client closes the connection
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logging.error(e)
|
|
||||||
continue
|
|
||||||
for sock, event in events:
|
|
||||||
if sock:
|
|
||||||
logging.debug('fd %d %s', sock.fileno(),
|
|
||||||
eventloop.EVENT_NAMES[event])
|
|
||||||
if sock == self._server_socket:
|
|
||||||
if event & eventloop.POLL_ERR:
|
|
||||||
# TODO
|
|
||||||
raise Exception('server_socket error')
|
|
||||||
try:
|
|
||||||
conn = self._server_socket.accept()
|
|
||||||
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
|
|
||||||
conn[0], self._config, self._is_local)
|
|
||||||
except (OSError, IOError) as e:
|
|
||||||
error_no = eventloop.errno_from_exception(e)
|
|
||||||
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logging.error(e)
|
|
||||||
else:
|
|
||||||
if sock:
|
|
||||||
handler = self._fd_to_handlers.get(sock.fileno(), None)
|
|
||||||
if handler:
|
|
||||||
handler.handle_event(sock, event)
|
|
||||||
else:
|
|
||||||
logging.warn('can not find handler for fd %d',
|
|
||||||
sock.fileno())
|
|
||||||
self._eventloop.remove(sock)
|
|
||||||
else:
|
|
||||||
logging.warn('poll removed fd')
|
|
||||||
now = time.time()
|
|
||||||
if now - last_time > 5:
|
|
||||||
# TODO sweep timeouts
|
|
||||||
last_time = now
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
# TODO combine loops on multiple ports into one single loop
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise Exception('closed')
|
raise Exception('already closed')
|
||||||
t = threading.Thread(target=self._run)
|
self._eventloop = loop
|
||||||
t.setName('TCPThread')
|
loop.add_handler(self._handle_events)
|
||||||
t.setDaemon(False)
|
|
||||||
t.start()
|
self._eventloop.add(self._server_socket,
|
||||||
self._thread = t
|
eventloop.POLL_IN | eventloop.POLL_ERR)
|
||||||
|
|
||||||
|
def _handle_events(self, events):
|
||||||
|
for sock, fd, event in events:
|
||||||
|
if sock:
|
||||||
|
logging.debug('fd %d %s', fd,
|
||||||
|
eventloop.EVENT_NAMES.get(event, event))
|
||||||
|
if sock == self._server_socket:
|
||||||
|
if event & eventloop.POLL_ERR:
|
||||||
|
# TODO
|
||||||
|
raise Exception('server_socket error')
|
||||||
|
try:
|
||||||
|
logging.debug('accept')
|
||||||
|
conn = self._server_socket.accept()
|
||||||
|
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
|
||||||
|
conn[0], self._config, self._is_local)
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
error_no = eventloop.errno_from_exception(e)
|
||||||
|
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logging.error(e)
|
||||||
|
else:
|
||||||
|
if sock:
|
||||||
|
handler = self._fd_to_handlers.get(fd, None)
|
||||||
|
if handler:
|
||||||
|
handler.handle_event(sock, event)
|
||||||
|
else:
|
||||||
|
logging.warn('poll removed fd')
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
if now - self._last_time > 5:
|
||||||
|
# TODO sweep timeouts
|
||||||
|
self._last_time = now
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._server_socket.close()
|
self._server_socket.close()
|
||||||
|
|
||||||
def thread(self):
|
|
||||||
return self._thread
|
|
|
@ -67,7 +67,6 @@
|
||||||
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import threading
|
|
||||||
import socket
|
import socket
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
@ -105,8 +104,10 @@ class UDPRelay(object):
|
||||||
close_callback=self._close_client)
|
close_callback=self._close_client)
|
||||||
self._client_fd_to_server_addr = \
|
self._client_fd_to_server_addr = \
|
||||||
lru_cache.LRUCache(timeout=config['timeout'])
|
lru_cache.LRUCache(timeout=config['timeout'])
|
||||||
|
self._eventloop = None
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._thread = None
|
self._last_time = time.time()
|
||||||
|
self._sockets = set()
|
||||||
|
|
||||||
addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0,
|
addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0,
|
||||||
socket.SOCK_DGRAM, socket.SOL_UDP)
|
socket.SOCK_DGRAM, socket.SOL_UDP)
|
||||||
|
@ -121,6 +122,7 @@ class UDPRelay(object):
|
||||||
|
|
||||||
def _close_client(self, client):
|
def _close_client(self, client):
|
||||||
if hasattr(client, 'close'):
|
if hasattr(client, 'close'):
|
||||||
|
self._sockets.remove(client.fileno())
|
||||||
self._eventloop.remove(client)
|
self._eventloop.remove(client)
|
||||||
client.close()
|
client.close()
|
||||||
else:
|
else:
|
||||||
|
@ -167,6 +169,7 @@ class UDPRelay(object):
|
||||||
else:
|
else:
|
||||||
# drop
|
# drop
|
||||||
return
|
return
|
||||||
|
self._sockets.add(client.fileno())
|
||||||
self._eventloop.add(client, eventloop.POLL_IN)
|
self._eventloop.add(client, eventloop.POLL_IN)
|
||||||
|
|
||||||
data = data[header_length:]
|
data = data[header_length:]
|
||||||
|
@ -216,45 +219,29 @@ class UDPRelay(object):
|
||||||
# simply drop that packet
|
# simply drop that packet
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _run(self):
|
def add_to_loop(self, loop):
|
||||||
server_socket = self._server_socket
|
|
||||||
self._eventloop = eventloop.EventLoop()
|
|
||||||
self._eventloop.add(server_socket, eventloop.POLL_IN)
|
|
||||||
last_time = time.time()
|
|
||||||
while not self._closed:
|
|
||||||
try:
|
|
||||||
events = self._eventloop.poll(10)
|
|
||||||
except (OSError, IOError) as e:
|
|
||||||
if eventloop.errno_from_exception(e) == errno.EPIPE:
|
|
||||||
# Happens when the client closes the connection
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logging.error(e)
|
|
||||||
continue
|
|
||||||
for sock, event in events:
|
|
||||||
if sock == self._server_socket:
|
|
||||||
self._handle_server()
|
|
||||||
else:
|
|
||||||
self._handle_client(sock)
|
|
||||||
now = time.time()
|
|
||||||
if now - last_time > 3.5:
|
|
||||||
self._cache.sweep()
|
|
||||||
if now - last_time > 7:
|
|
||||||
self._client_fd_to_server_addr.sweep()
|
|
||||||
last_time = now
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if self._closed:
|
if self._closed:
|
||||||
raise Exception('closed')
|
raise Exception('already closed')
|
||||||
t = threading.Thread(target=self._run)
|
self._eventloop = loop
|
||||||
t.setName('UDPThread')
|
loop.add_handler(self._handle_events)
|
||||||
t.setDaemon(False)
|
|
||||||
t.start()
|
server_socket = self._server_socket
|
||||||
self._thread = t
|
self._eventloop.add(server_socket,
|
||||||
|
eventloop.POLL_IN | eventloop.POLL_ERR)
|
||||||
|
|
||||||
|
def _handle_events(self, events):
|
||||||
|
for sock, fd, event in events:
|
||||||
|
if sock == self._server_socket:
|
||||||
|
self._handle_server()
|
||||||
|
elif sock and (fd in self._sockets):
|
||||||
|
self._handle_client(sock)
|
||||||
|
now = time.time()
|
||||||
|
if now - self._last_time > 3.5:
|
||||||
|
self._cache.sweep()
|
||||||
|
if now - self._last_time > 7:
|
||||||
|
self._client_fd_to_server_addr.sweep()
|
||||||
|
self._last_time = now
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._closed = True
|
self._closed = True
|
||||||
self._server_socket.close()
|
self._server_socket.close()
|
||||||
|
|
||||||
def thread(self):
|
|
||||||
return self._thread
|
|
||||||
|
|
Loading…
Reference in a new issue