refine loop

This commit is contained in:
clowwindy 2014-06-01 19:09:52 +08:00
parent 5e19fdc66b
commit 0c8a8ef23f
5 changed files with 120 additions and 136 deletions

View file

@ -28,6 +28,8 @@
import os
import socket
import select
import errno
import logging
from collections import defaultdict
@ -154,25 +156,24 @@ class EventLoop(object):
def __init__(self):
if hasattr(select, 'epoll'):
self._impl = EpollLoop()
self._model = 'epoll'
model = 'epoll'
elif hasattr(select, 'kqueue'):
self._impl = KqueueLoop()
self._model = 'kqueue'
model = 'kqueue'
elif hasattr(select, 'select'):
self._impl = SelectLoop()
self._model = 'select'
model = 'select'
else:
raise Exception('can not find any available functions in select '
'package')
self._fd_to_f = {}
@property
def model(self):
return self._model
self._handlers = []
self.stopping = False
logging.debug('using event model: %s', model)
def poll(self, timeout=None):
events = self._impl.poll(timeout)
return ((self._fd_to_f[fd], event) for fd, event in events)
return [(self._fd_to_f[fd], fd, event) for fd, event in events]
def add(self, f, mode):
fd = f.fileno()
@ -188,6 +189,26 @@ class EventLoop(object):
fd = f.fileno()
self._impl.modify_fd(fd, mode)
def add_handler(self, handler):
self._handlers.append(handler)
def run(self):
while not self.stopping:
events = None
try:
events = self.poll(1)
except (OSError, IOError) as e:
if errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for handler in self._handlers:
# no exceptions should be raised by users
# TODO when there are a lot of handlers
handler(events)
# from tornado
def errno_from_exception(e):

View file

@ -24,8 +24,9 @@
import sys
import os
import logging
import encrypt
import utils
import encrypt
import eventloop
import tcprelay
import udprelay
@ -49,11 +50,12 @@ def main():
logging.info("starting local at %s:%d" %
(config['local_address'], config['local_port']))
# TODO combine the two threads into one loop on a single thread
udprelay.UDPRelay(config, True).start()
tcprelay.TCPRelay(config, True).start()
while sys.stdin.read():
pass
tcp_server = tcprelay.TCPRelay(config, True)
udp_server = udprelay.UDPRelay(config, True)
loop = eventloop.EventLoop()
tcp_server.add_to_loop(loop)
udp_server.add_to_loop(loop)
loop.run()
except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e)
os._exit(0)

View file

@ -22,11 +22,11 @@
# SOFTWARE.
import sys
import socket
import logging
import encrypt
import os
import logging
import utils
import encrypt
import eventloop
import tcprelay
import udprelay
@ -56,19 +56,17 @@ def main():
a_config['password'] = password
logging.info("starting server at %s:%d" %
(a_config['server'], int(port)))
tcp_server = tcprelay.TCPRelay(a_config, False)
tcp_servers.append(tcp_server)
udp_server = udprelay.UDPRelay(a_config, False)
udp_servers.append(udp_server)
tcp_servers.append(tcprelay.TCPRelay(a_config, False))
udp_servers.append(udprelay.UDPRelay(a_config, False))
def run_server():
try:
loop = eventloop.EventLoop()
for tcp_server in tcp_servers:
tcp_server.start()
tcp_server.add_to_loop(loop)
for udp_server in udp_servers:
udp_server.start()
while sys.stdin.read():
pass
udp_server.add_to_loop(loop)
loop.run()
except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e)
os._exit(0)
@ -96,10 +94,10 @@ def main():
signal.signal(signal.SIGTERM, handler)
# master
for tcp_server in tcp_servers:
tcp_server.close()
for udp_server in udp_servers:
udp_server.close()
for a_tcp_server in tcp_servers:
a_tcp_server.close()
for a_udp_server in udp_servers:
a_udp_server.close()
for child in children:
os.waitpid(child, 0)
@ -111,7 +109,4 @@ def main():
if __name__ == '__main__':
try:
main()
except socket.error, e:
logging.error(e)
main()

View file

@ -26,7 +26,6 @@ import socket
import logging
import encrypt
import errno
import threading
import eventloop
from common import parse_header
@ -303,6 +302,7 @@ class TCPRelayHandler(object):
logging.warn('unknown socket')
def destroy(self):
logging.debug('destroy')
if self._remote_sock:
self._loop.remove(self._remote_sock)
del self._fd_to_handlers[self._remote_sock.fileno()]
@ -320,8 +320,9 @@ class TCPRelay(object):
self._config = config
self._is_local = is_local
self._closed = False
self._thread = None
self._eventloop = None
self._fd_to_handlers = {}
self._last_time = time.time()
if is_local:
listen_addr = config['local_address']
@ -343,70 +344,48 @@ class TCPRelay(object):
server_socket.listen(1024)
self._server_socket = server_socket
def _run(self):
server_socket = self._server_socket
self._eventloop = eventloop.EventLoop()
logging.debug('using event model: %s', self._eventloop.model)
self._eventloop.add(server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR)
last_time = time.time()
while not self._closed:
try:
events = self._eventloop.poll(1)
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for sock, event in events:
if sock:
logging.debug('fd %d %s', sock.fileno(),
eventloop.EVENT_NAMES[event])
if sock == self._server_socket:
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
conn = self._server_socket.accept()
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
conn[0], self._config, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
continue
else:
logging.error(e)
else:
if sock:
handler = self._fd_to_handlers.get(sock.fileno(), None)
if handler:
handler.handle_event(sock, event)
else:
logging.warn('can not find handler for fd %d',
sock.fileno())
self._eventloop.remove(sock)
else:
logging.warn('poll removed fd')
now = time.time()
if now - last_time > 5:
# TODO sweep timeouts
last_time = now
def start(self):
# TODO combine loops on multiple ports into one single loop
def add_to_loop(self, loop):
if self._closed:
raise Exception('closed')
t = threading.Thread(target=self._run)
t.setName('TCPThread')
t.setDaemon(False)
t.start()
self._thread = t
raise Exception('already closed')
self._eventloop = loop
loop.add_handler(self._handle_events)
self._eventloop.add(self._server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR)
def _handle_events(self, events):
for sock, fd, event in events:
if sock:
logging.debug('fd %d %s', fd,
eventloop.EVENT_NAMES.get(event, event))
if sock == self._server_socket:
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
logging.debug('accept')
conn = self._server_socket.accept()
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
conn[0], self._config, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
continue
else:
logging.error(e)
else:
if sock:
handler = self._fd_to_handlers.get(fd, None)
if handler:
handler.handle_event(sock, event)
else:
logging.warn('poll removed fd')
now = time.time()
if now - self._last_time > 5:
# TODO sweep timeouts
self._last_time = now
def close(self):
self._closed = True
self._server_socket.close()
def thread(self):
return self._thread

View file

@ -67,7 +67,6 @@
import time
import threading
import socket
import logging
import struct
@ -105,8 +104,10 @@ class UDPRelay(object):
close_callback=self._close_client)
self._client_fd_to_server_addr = \
lru_cache.LRUCache(timeout=config['timeout'])
self._eventloop = None
self._closed = False
self._thread = None
self._last_time = time.time()
self._sockets = set()
addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP)
@ -121,6 +122,7 @@ class UDPRelay(object):
def _close_client(self, client):
if hasattr(client, 'close'):
self._sockets.remove(client.fileno())
self._eventloop.remove(client)
client.close()
else:
@ -167,6 +169,7 @@ class UDPRelay(object):
else:
# drop
return
self._sockets.add(client.fileno())
self._eventloop.add(client, eventloop.POLL_IN)
data = data[header_length:]
@ -216,45 +219,29 @@ class UDPRelay(object):
# simply drop that packet
pass
def _run(self):
server_socket = self._server_socket
self._eventloop = eventloop.EventLoop()
self._eventloop.add(server_socket, eventloop.POLL_IN)
last_time = time.time()
while not self._closed:
try:
events = self._eventloop.poll(10)
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for sock, event in events:
if sock == self._server_socket:
self._handle_server()
else:
self._handle_client(sock)
now = time.time()
if now - last_time > 3.5:
self._cache.sweep()
if now - last_time > 7:
self._client_fd_to_server_addr.sweep()
last_time = now
def start(self):
def add_to_loop(self, loop):
if self._closed:
raise Exception('closed')
t = threading.Thread(target=self._run)
t.setName('UDPThread')
t.setDaemon(False)
t.start()
self._thread = t
raise Exception('already closed')
self._eventloop = loop
loop.add_handler(self._handle_events)
server_socket = self._server_socket
self._eventloop.add(server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR)
def _handle_events(self, events):
for sock, fd, event in events:
if sock == self._server_socket:
self._handle_server()
elif sock and (fd in self._sockets):
self._handle_client(sock)
now = time.time()
if now - self._last_time > 3.5:
self._cache.sweep()
if now - self._last_time > 7:
self._client_fd_to_server_addr.sweep()
self._last_time = now
def close(self):
self._closed = True
self._server_socket.close()
def thread(self):
return self._thread