From fbc4906445860d9ee84adec8947303634ba06442 Mon Sep 17 00:00:00 2001 From: clowwindy Date: Sun, 8 Jun 2014 15:41:39 +0800 Subject: [PATCH] add cache --- shadowsocks/asyncdns.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 259511b..7c52b16 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -25,6 +25,7 @@ import socket import struct import logging import common +import lru_cache import eventloop @@ -242,8 +243,7 @@ class DNSResolver(object): self._hostname_status = {} self._hostname_to_cb = {} self._cb_to_hostname = {} - # TODO add caching - # TODO try ipv4 and ipv6 sequencely + self._cache = lru_cache.LRUCache(timeout=300) self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP) self._sock.setblocking(False) @@ -276,11 +276,21 @@ class DNSResolver(object): loop.add(self._sock, eventloop.POLL_IN) loop.add_handler(self.handle_events) + def _call_callback(self, hostname, ip): + callbacks = self._hostname_to_cb.get(hostname, []) + for callback in callbacks: + if self._cb_to_hostname.__contains__(callback): + del self._cb_to_hostname[callback] + callback((hostname, ip), None) + if self._hostname_to_cb.__contains__(hostname): + del self._hostname_to_cb[hostname] + if self._hostname_status.__contains__(hostname): + del self._hostname_status[hostname] + def _handle_data(self, data): response = parse_response(data) if response and response.hostname: hostname = response.hostname - callbacks = self._hostname_to_cb.get(hostname, []) ip = None for answer in response.answers: if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ @@ -291,15 +301,9 @@ class DNSResolver(object): == STATUS_IPV4: self._hostname_status[hostname] = STATUS_IPV6 self._send_req(hostname, QTYPE_AAAA) - return - for callback in callbacks: - if self._cb_to_hostname.__contains__(callback): - del self._cb_to_hostname[callback] - callback((hostname, ip), None) - if self._hostname_to_cb.__contains__(hostname): - del self._hostname_to_cb[hostname] - if self._hostname_status.__contains__(hostname): - del self._hostname_status[hostname] + else: + self._cache[hostname] = ip + self._call_callback(hostname, ip) def handle_events(self, events): for sock, fd, event in events: @@ -344,6 +348,10 @@ class DNSResolver(object): callback(None, Exception('empty hostname')) elif is_ip(hostname): callback(hostname, None) + elif self._cache.__contains__(hostname): + logging.debug('hit cache: %s', hostname) + ip = self._cache[hostname] + callback(ip, None) else: arr = self._hostname_to_cb.get(hostname, None) if not arr: