more work

This commit is contained in:
clowwindy 2014-06-08 13:55:19 +08:00
parent 7a77205530
commit 6b76319495

View file

@ -23,7 +23,9 @@
import socket
import struct
import logging
import common
import eventloop
_request_count = 1
@ -66,18 +68,7 @@ QTYPE_CNAME = 5
QCLASS_IN = 1
def parse_ip(addrtype, data, length, offset):
if addrtype == QTYPE_A:
return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
elif addrtype == QTYPE_AAAA:
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
elif addrtype == QTYPE_CNAME:
return parse_name(data, offset, length)[1]
else:
return data
def pack_address(address):
def build_address(address):
address = address.strip('.')
labels = address.split('.')
results = []
@ -91,17 +82,28 @@ def pack_address(address):
return ''.join(results)
def pack_request(address):
def build_request(address, qtype):
global _request_count
header = struct.pack('!HBBHHHH', _request_count, 1, 0, 1, 0, 0, 0)
addr = pack_address(address)
qtype_qclass = struct.pack('!HH', QTYPE_ANY, QCLASS_IN)
addr = build_address(address)
qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN)
_request_count += 1
if _request_count > 65535:
_request_count = 1
return header + addr + qtype_qclass
def parse_ip(addrtype, data, length, offset):
if addrtype == QTYPE_A:
return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
elif addrtype == QTYPE_AAAA:
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
elif addrtype == QTYPE_CNAME:
return parse_name(data, offset, length)[1]
else:
return data
def parse_name(data, offset, length=512):
p = offset
if (ord(data[offset]) & (128 + 64)) == (128 + 64):
@ -110,7 +112,7 @@ def parse_name(data, offset, length=512):
pointer = pointer & 0x3FFF
if pointer == offset:
return (0, None)
return (2, parse_name(data, pointer))
return (2, parse_name(data, pointer)[1])
else:
labels = []
l = ord(data[p])
@ -173,7 +175,7 @@ def parse_record(data, offset, question=False):
return len + 4, (name, None, record_type, record_class, None, None)
def unpack_response(data):
def parse_response(data):
try:
if len(data) >= 12:
header = struct.unpack('!HBBHHHH', data[:12])
@ -214,39 +216,185 @@ def unpack_response(data):
offset += l
if r:
ars.append(r)
return ans
response = DNSResponse()
if qds:
response.hostname = qds[0][0]
for an in ans:
response.answers.append((an[1], an[2], an[3]))
return response
except Exception as e:
import traceback
traceback.print_exc()
return None
def resolve(address, callback):
# TODO async
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP)
req = pack_request(address)
if req is None:
# TODO
return
sock.sendto(req, ('8.8.8.8', 53))
res, addr = sock.recvfrom(1024)
parsed_res = unpack_response(res)
callback(parsed_res)
def is_ip(address):
for family in (socket.AF_INET, socket.AF_INET6):
try:
socket.inet_pton(family, address)
return True
except (OSError, IOError):
pass
return False
class DNSResponse(object):
def __init__(self):
self.hostname = None
self.answers = [] # each: (addr, type, class)
def __str__(self):
return '%s: %s' % (self.hostname, str(self.answers))
STATUS_IPV4 = 0
STATUS_IPV6 = 1
class DNSResolver(object):
def __init__(self):
self._loop = None
self._hostname_status = {}
self._hostname_to_cb = {}
self._cb_to_hostname = {}
# TODO add caching
# TODO try ipv4 and ipv6 sequencely
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
socket.SOL_UDP)
self._sock.setblocking(False)
self._parse_config()
def _parse_config(self):
try:
with open('/etc/resolv.conf', 'rb') as f:
servers = []
content = f.readlines()
for line in content:
line = line.strip()
if line:
if line.startswith('nameserver'):
parts = line.split(' ')
if len(parts) >= 2:
server = parts[1]
if is_ip(server):
servers.append(server)
# TODO support more servers
if servers:
self._dns_server = (servers[0], 53)
return
except IOError:
pass
self._dns_server = ('8.8.8.8', 53)
def add_to_loop(self, loop):
self._loop = loop
loop.add(self._sock, eventloop.POLL_IN)
loop.add_handler(self.handle_events)
def _handle_data(self, data):
response = parse_response(data)
if response and response.hostname:
hostname = response.hostname
callbacks = self._hostname_to_cb.get(hostname, [])
ip = None
for answer in response.answers:
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
answer[2] == QCLASS_IN:
ip = answer[0]
break
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
== STATUS_IPV4:
self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA)
return
for callback in callbacks:
if self._cb_to_hostname.__contains__(callback):
del self._cb_to_hostname[callback]
callback((hostname, ip), None)
if self._hostname_to_cb.__contains__(hostname):
del self._hostname_to_cb[hostname]
def handle_events(self, events):
for sock, fd, event in events:
if sock != self._sock:
continue
if event & eventloop.POLL_ERR:
logging.error('dns socket err')
self._loop.remove(self._sock)
self._sock.close()
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
socket.SOL_UDP)
self._sock.setblocking(False)
self._loop.add(self._sock, eventloop.POLL_IN)
else:
data, addr = sock.recvfrom(1024)
if addr != self._dns_server:
logging.warn('received a packet other than our dns')
break
self._handle_data(data)
break
def remove_callback(self, callback):
hostname = self._cb_to_hostname.get(callback)
if hostname:
del self._cb_to_hostname[callback]
arr = self._hostname_to_cb.get(hostname, None)
if arr:
arr.remove(callback)
if not arr:
del self._hostname_to_cb[hostname]
def _send_req(self, hostname, qtype):
logging.debug('resolving %s with type %d using server %s', hostname,
qtype, self._dns_server)
req = build_request(hostname, qtype)
self._sock.sendto(req, self._dns_server)
def resolve(self, hostname, callback):
if not hostname:
callback(None, Exception('empty hostname'))
elif is_ip(hostname):
callback(hostname, None)
else:
arr = self._hostname_to_cb.get(hostname, None)
if not arr:
self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_A)
self._hostname_to_cb[hostname] = [callback]
self._cb_to_hostname[callback] = hostname
else:
arr.append(callback)
def test():
def _callback(address):
print address
logging.getLogger('').handlers = []
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', filemode='a+')
resolve('www.twitter.com', _callback)
resolve('www.google.com', _callback)
resolve('ipv6.google.com', _callback)
resolve('ipv6.l.google.com', _callback)
resolve('www.baidu.com', _callback)
resolve('www.a.shifen.com', _callback)
resolve('m.baidu.jp', _callback)
def _callback(address, error):
print error, address
loop = eventloop.EventLoop()
resolver = DNSResolver()
resolver.add_to_loop(loop)
resolver.resolve('8.8.8.8', _callback)
resolver.resolve('www.twitter.com', _callback)
resolver.resolve('www.google.com', _callback)
resolver.resolve('ipv6.google.com', _callback)
resolver.resolve('ipv6.l.google.com', _callback)
resolver.resolve('www.gmail.com', _callback)
resolver.resolve('r4---sn-3qqp-ioql.googlevideo.com', _callback)
resolver.resolve('www.baidu.com', _callback)
resolver.resolve('www.a.shifen.com', _callback)
resolver.resolve('m.baidu.jp', _callback)
resolver.resolve('www.youku.com', _callback)
resolver.resolve('www.twitter.com', _callback)
resolver.resolve('ipv6.google.com', _callback)
loop.run()
if __name__ == '__main__':