refine loop

This commit is contained in:
clowwindy 2014-06-01 19:09:52 +08:00
parent 5e19fdc66b
commit 0c8a8ef23f
5 changed files with 120 additions and 136 deletions

View file

@ -28,6 +28,8 @@
import os import os
import socket import socket
import select import select
import errno
import logging
from collections import defaultdict from collections import defaultdict
@ -154,25 +156,24 @@ class EventLoop(object):
def __init__(self): def __init__(self):
if hasattr(select, 'epoll'): if hasattr(select, 'epoll'):
self._impl = EpollLoop() self._impl = EpollLoop()
self._model = 'epoll' model = 'epoll'
elif hasattr(select, 'kqueue'): elif hasattr(select, 'kqueue'):
self._impl = KqueueLoop() self._impl = KqueueLoop()
self._model = 'kqueue' model = 'kqueue'
elif hasattr(select, 'select'): elif hasattr(select, 'select'):
self._impl = SelectLoop() self._impl = SelectLoop()
self._model = 'select' model = 'select'
else: else:
raise Exception('can not find any available functions in select ' raise Exception('can not find any available functions in select '
'package') 'package')
self._fd_to_f = {} self._fd_to_f = {}
self._handlers = []
@property self.stopping = False
def model(self): logging.debug('using event model: %s', model)
return self._model
def poll(self, timeout=None): def poll(self, timeout=None):
events = self._impl.poll(timeout) events = self._impl.poll(timeout)
return ((self._fd_to_f[fd], event) for fd, event in events) return [(self._fd_to_f[fd], fd, event) for fd, event in events]
def add(self, f, mode): def add(self, f, mode):
fd = f.fileno() fd = f.fileno()
@ -188,6 +189,26 @@ class EventLoop(object):
fd = f.fileno() fd = f.fileno()
self._impl.modify_fd(fd, mode) self._impl.modify_fd(fd, mode)
def add_handler(self, handler):
self._handlers.append(handler)
def run(self):
while not self.stopping:
events = None
try:
events = self.poll(1)
except (OSError, IOError) as e:
if errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for handler in self._handlers:
# no exceptions should be raised by users
# TODO when there are a lot of handlers
handler(events)
# from tornado # from tornado
def errno_from_exception(e): def errno_from_exception(e):

View file

@ -24,8 +24,9 @@
import sys import sys
import os import os
import logging import logging
import encrypt
import utils import utils
import encrypt
import eventloop
import tcprelay import tcprelay
import udprelay import udprelay
@ -49,11 +50,12 @@ def main():
logging.info("starting local at %s:%d" % logging.info("starting local at %s:%d" %
(config['local_address'], config['local_port'])) (config['local_address'], config['local_port']))
# TODO combine the two threads into one loop on a single thread tcp_server = tcprelay.TCPRelay(config, True)
udprelay.UDPRelay(config, True).start() udp_server = udprelay.UDPRelay(config, True)
tcprelay.TCPRelay(config, True).start() loop = eventloop.EventLoop()
while sys.stdin.read(): tcp_server.add_to_loop(loop)
pass udp_server.add_to_loop(loop)
loop.run()
except (KeyboardInterrupt, IOError, OSError) as e: except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e) logging.error(e)
os._exit(0) os._exit(0)

View file

@ -22,11 +22,11 @@
# SOFTWARE. # SOFTWARE.
import sys import sys
import socket
import logging
import encrypt
import os import os
import logging
import utils import utils
import encrypt
import eventloop
import tcprelay import tcprelay
import udprelay import udprelay
@ -56,19 +56,17 @@ def main():
a_config['password'] = password a_config['password'] = password
logging.info("starting server at %s:%d" % logging.info("starting server at %s:%d" %
(a_config['server'], int(port))) (a_config['server'], int(port)))
tcp_server = tcprelay.TCPRelay(a_config, False) tcp_servers.append(tcprelay.TCPRelay(a_config, False))
tcp_servers.append(tcp_server) udp_servers.append(udprelay.UDPRelay(a_config, False))
udp_server = udprelay.UDPRelay(a_config, False)
udp_servers.append(udp_server)
def run_server(): def run_server():
try: try:
loop = eventloop.EventLoop()
for tcp_server in tcp_servers: for tcp_server in tcp_servers:
tcp_server.start() tcp_server.add_to_loop(loop)
for udp_server in udp_servers: for udp_server in udp_servers:
udp_server.start() udp_server.add_to_loop(loop)
while sys.stdin.read(): loop.run()
pass
except (KeyboardInterrupt, IOError, OSError) as e: except (KeyboardInterrupt, IOError, OSError) as e:
logging.error(e) logging.error(e)
os._exit(0) os._exit(0)
@ -96,10 +94,10 @@ def main():
signal.signal(signal.SIGTERM, handler) signal.signal(signal.SIGTERM, handler)
# master # master
for tcp_server in tcp_servers: for a_tcp_server in tcp_servers:
tcp_server.close() a_tcp_server.close()
for udp_server in udp_servers: for a_udp_server in udp_servers:
udp_server.close() a_udp_server.close()
for child in children: for child in children:
os.waitpid(child, 0) os.waitpid(child, 0)
@ -111,7 +109,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
try: main()
main()
except socket.error, e:
logging.error(e)

View file

@ -26,7 +26,6 @@ import socket
import logging import logging
import encrypt import encrypt
import errno import errno
import threading
import eventloop import eventloop
from common import parse_header from common import parse_header
@ -303,6 +302,7 @@ class TCPRelayHandler(object):
logging.warn('unknown socket') logging.warn('unknown socket')
def destroy(self): def destroy(self):
logging.debug('destroy')
if self._remote_sock: if self._remote_sock:
self._loop.remove(self._remote_sock) self._loop.remove(self._remote_sock)
del self._fd_to_handlers[self._remote_sock.fileno()] del self._fd_to_handlers[self._remote_sock.fileno()]
@ -320,8 +320,9 @@ class TCPRelay(object):
self._config = config self._config = config
self._is_local = is_local self._is_local = is_local
self._closed = False self._closed = False
self._thread = None self._eventloop = None
self._fd_to_handlers = {} self._fd_to_handlers = {}
self._last_time = time.time()
if is_local: if is_local:
listen_addr = config['local_address'] listen_addr = config['local_address']
@ -343,70 +344,48 @@ class TCPRelay(object):
server_socket.listen(1024) server_socket.listen(1024)
self._server_socket = server_socket self._server_socket = server_socket
def _run(self): def add_to_loop(self, loop):
server_socket = self._server_socket
self._eventloop = eventloop.EventLoop()
logging.debug('using event model: %s', self._eventloop.model)
self._eventloop.add(server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR)
last_time = time.time()
while not self._closed:
try:
events = self._eventloop.poll(1)
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for sock, event in events:
if sock:
logging.debug('fd %d %s', sock.fileno(),
eventloop.EVENT_NAMES[event])
if sock == self._server_socket:
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
conn = self._server_socket.accept()
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
conn[0], self._config, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
continue
else:
logging.error(e)
else:
if sock:
handler = self._fd_to_handlers.get(sock.fileno(), None)
if handler:
handler.handle_event(sock, event)
else:
logging.warn('can not find handler for fd %d',
sock.fileno())
self._eventloop.remove(sock)
else:
logging.warn('poll removed fd')
now = time.time()
if now - last_time > 5:
# TODO sweep timeouts
last_time = now
def start(self):
# TODO combine loops on multiple ports into one single loop
if self._closed: if self._closed:
raise Exception('closed') raise Exception('already closed')
t = threading.Thread(target=self._run) self._eventloop = loop
t.setName('TCPThread') loop.add_handler(self._handle_events)
t.setDaemon(False)
t.start() self._eventloop.add(self._server_socket,
self._thread = t eventloop.POLL_IN | eventloop.POLL_ERR)
def _handle_events(self, events):
for sock, fd, event in events:
if sock:
logging.debug('fd %d %s', fd,
eventloop.EVENT_NAMES.get(event, event))
if sock == self._server_socket:
if event & eventloop.POLL_ERR:
# TODO
raise Exception('server_socket error')
try:
logging.debug('accept')
conn = self._server_socket.accept()
TCPRelayHandler(self._fd_to_handlers, self._eventloop,
conn[0], self._config, self._is_local)
except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e)
if error_no in (errno.EAGAIN, errno.EINPROGRESS):
continue
else:
logging.error(e)
else:
if sock:
handler = self._fd_to_handlers.get(fd, None)
if handler:
handler.handle_event(sock, event)
else:
logging.warn('poll removed fd')
now = time.time()
if now - self._last_time > 5:
# TODO sweep timeouts
self._last_time = now
def close(self): def close(self):
self._closed = True self._closed = True
self._server_socket.close() self._server_socket.close()
def thread(self):
return self._thread

View file

@ -67,7 +67,6 @@
import time import time
import threading
import socket import socket
import logging import logging
import struct import struct
@ -105,8 +104,10 @@ class UDPRelay(object):
close_callback=self._close_client) close_callback=self._close_client)
self._client_fd_to_server_addr = \ self._client_fd_to_server_addr = \
lru_cache.LRUCache(timeout=config['timeout']) lru_cache.LRUCache(timeout=config['timeout'])
self._eventloop = None
self._closed = False self._closed = False
self._thread = None self._last_time = time.time()
self._sockets = set()
addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0,
socket.SOCK_DGRAM, socket.SOL_UDP) socket.SOCK_DGRAM, socket.SOL_UDP)
@ -121,6 +122,7 @@ class UDPRelay(object):
def _close_client(self, client): def _close_client(self, client):
if hasattr(client, 'close'): if hasattr(client, 'close'):
self._sockets.remove(client.fileno())
self._eventloop.remove(client) self._eventloop.remove(client)
client.close() client.close()
else: else:
@ -167,6 +169,7 @@ class UDPRelay(object):
else: else:
# drop # drop
return return
self._sockets.add(client.fileno())
self._eventloop.add(client, eventloop.POLL_IN) self._eventloop.add(client, eventloop.POLL_IN)
data = data[header_length:] data = data[header_length:]
@ -216,45 +219,29 @@ class UDPRelay(object):
# simply drop that packet # simply drop that packet
pass pass
def _run(self): def add_to_loop(self, loop):
server_socket = self._server_socket
self._eventloop = eventloop.EventLoop()
self._eventloop.add(server_socket, eventloop.POLL_IN)
last_time = time.time()
while not self._closed:
try:
events = self._eventloop.poll(10)
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for sock, event in events:
if sock == self._server_socket:
self._handle_server()
else:
self._handle_client(sock)
now = time.time()
if now - last_time > 3.5:
self._cache.sweep()
if now - last_time > 7:
self._client_fd_to_server_addr.sweep()
last_time = now
def start(self):
if self._closed: if self._closed:
raise Exception('closed') raise Exception('already closed')
t = threading.Thread(target=self._run) self._eventloop = loop
t.setName('UDPThread') loop.add_handler(self._handle_events)
t.setDaemon(False)
t.start() server_socket = self._server_socket
self._thread = t self._eventloop.add(server_socket,
eventloop.POLL_IN | eventloop.POLL_ERR)
def _handle_events(self, events):
for sock, fd, event in events:
if sock == self._server_socket:
self._handle_server()
elif sock and (fd in self._sockets):
self._handle_client(sock)
now = time.time()
if now - self._last_time > 3.5:
self._cache.sweep()
if now - self._last_time > 7:
self._client_fd_to_server_addr.sweep()
self._last_time = now
def close(self): def close(self):
self._closed = True self._closed = True
self._server_socket.close() self._server_socket.close()
def thread(self):
return self._thread