fix remove()

This commit is contained in:
clowwindy 2014-04-24 12:34:31 +08:00
parent 6834178a89
commit e6a225513e
3 changed files with 62 additions and 24 deletions

View file

@ -29,15 +29,15 @@ import select
from collections import defaultdict
__all__ = ['EventLoop', 'MODE_NULL', 'MODE_IN', 'MODE_OUT', 'MODE_ERR',
'MODE_HUP', 'MODE_NVAL']
__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR',
'POLL_HUP', 'POLL_NVAL']
MODE_NULL = 0x00
MODE_IN = 0x01
MODE_OUT = 0x04
MODE_ERR = 0x08
MODE_HUP = 0x10
MODE_NVAL = 0x20
POLL_NULL = 0x00
POLL_IN = 0x01
POLL_OUT = 0x04
POLL_ERR = 0x08
POLL_HUP = 0x10
POLL_NVAL = 0x20
class EpollLoop(object):
@ -68,9 +68,9 @@ class KqueueLoop(object):
def _control(self, fd, mode, flags):
events = []
if mode & MODE_IN:
if mode & POLL_IN:
events.append(select.kevent(fd, select.KQ_FILTER_READ, flags))
if mode & MODE_OUT:
if mode & POLL_OUT:
events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags))
for e in events:
self._kqueue.control([e], 0)
@ -79,13 +79,13 @@ class KqueueLoop(object):
if timeout < 0:
timeout = None # kqueue behaviour
events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout)
results = defaultdict(lambda: MODE_NULL)
results = defaultdict(lambda: POLL_NULL)
for e in events:
fd = e.ident
if e.filter == select.KQ_FILTER_READ:
results[fd] |= MODE_IN
results[fd] |= POLL_IN
elif e.filter == select.KQ_FILTER_WRITE:
results[fd] |= MODE_OUT
results[fd] |= POLL_OUT
return results.iteritems()
def add_fd(self, fd, mode):
@ -111,18 +111,18 @@ class SelectLoop(object):
def poll(self, timeout):
r, w, x = select.select(self._r_list, self._w_list, self._x_list,
timeout)
results = defaultdict(lambda: MODE_NULL)
for p in [(r, MODE_IN), (w, MODE_OUT), (x, MODE_ERR)]:
results = defaultdict(lambda: POLL_NULL)
for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]:
for fd in p[0]:
results[fd] |= p[1]
return results.items()
def add_fd(self, fd, mode):
if mode & MODE_IN:
if mode & POLL_IN:
self._r_list.add(fd)
if mode & MODE_OUT:
if mode & POLL_OUT:
self._w_list.add(fd)
if mode & MODE_ERR:
if mode & POLL_ERR:
self._x_list.add(fd)
def remove_fd(self, fd):
@ -168,3 +168,22 @@ class EventLoop(object):
def modify(self, f, mode):
fd = f.fileno()
self._impl.modify_fd(fd, mode)
# from tornado
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None

View file

@ -10,8 +10,9 @@ import time
class LRUCache(collections.MutableMapping):
"""This class is not thread safe"""
def __init__(self, timeout=60, *args, **kwargs):
def __init__(self, timeout=60, close_callback=None, *args, **kwargs):
self.timeout = timeout
self.close_callback = close_callback
self.store = {}
self.time_to_keys = collections.defaultdict(list)
self.last_visits = []
@ -53,8 +54,9 @@ class LRUCache(collections.MutableMapping):
heapq.heappop(self.last_visits)
if self.store.__contains__(key):
value = self.store[key]
if hasattr(value, 'close'):
value.close()
if self.close_callback is not None:
self.close_callback(value)
del self.store[key]
c += 1
del self.time_to_keys[least]

View file

@ -74,6 +74,7 @@ import struct
import encrypt
import eventloop
import lru_cache
import errno
BUF_SIZE = 65536
@ -137,6 +138,14 @@ class UDPRelay(object):
self._cache = lru_cache.LRUCache(timeout=timeout)
self._client_fd_to_server_addr = lru_cache.LRUCache(timeout=timeout)
def _close_client(self, client):
if hasattr(client, 'close'):
self._eventloop.remove(client)
client.close()
else:
# just an address
pass
def _handle_server(self):
server = self._server_socket
data, r_addr = server.recvfrom(BUF_SIZE)
@ -177,7 +186,7 @@ class UDPRelay(object):
else:
# drop
return
self._eventloop.add(client, eventloop.MODE_IN)
self._eventloop.add(client, eventloop.POLL_IN)
# prevent from recv other sources
if self._is_local:
@ -225,10 +234,18 @@ class UDPRelay(object):
def _run(self):
server_socket = self._server_socket
self._eventloop.add(server_socket, eventloop.MODE_IN)
self._eventloop.add(server_socket, eventloop.POLL_IN)
last_time = time.time()
while True:
events = self._eventloop.poll(10)
try:
events = self._eventloop.poll(10)
except (OSError, IOError) as e:
if eventloop.errno_from_exception(e) == errno.EPIPE:
# Happens when the client closes the connection
continue
else:
logging.error(e)
continue
for sock, event in events:
if sock == self._server_socket:
self._handle_server()