more work
This commit is contained in:
parent
7a77205530
commit
6b76319495
1 changed files with 188 additions and 40 deletions
|
@ -23,7 +23,9 @@
|
|||
|
||||
import socket
|
||||
import struct
|
||||
import logging
|
||||
import common
|
||||
import eventloop
|
||||
|
||||
|
||||
_request_count = 1
|
||||
|
@ -66,18 +68,7 @@ QTYPE_CNAME = 5
|
|||
QCLASS_IN = 1
|
||||
|
||||
|
||||
def parse_ip(addrtype, data, length, offset):
|
||||
if addrtype == QTYPE_A:
|
||||
return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
|
||||
elif addrtype == QTYPE_AAAA:
|
||||
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
|
||||
elif addrtype == QTYPE_CNAME:
|
||||
return parse_name(data, offset, length)[1]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def pack_address(address):
|
||||
def build_address(address):
|
||||
address = address.strip('.')
|
||||
labels = address.split('.')
|
||||
results = []
|
||||
|
@ -91,17 +82,28 @@ def pack_address(address):
|
|||
return ''.join(results)
|
||||
|
||||
|
||||
def pack_request(address):
|
||||
def build_request(address, qtype):
|
||||
global _request_count
|
||||
header = struct.pack('!HBBHHHH', _request_count, 1, 0, 1, 0, 0, 0)
|
||||
addr = pack_address(address)
|
||||
qtype_qclass = struct.pack('!HH', QTYPE_ANY, QCLASS_IN)
|
||||
addr = build_address(address)
|
||||
qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN)
|
||||
_request_count += 1
|
||||
if _request_count > 65535:
|
||||
_request_count = 1
|
||||
return header + addr + qtype_qclass
|
||||
|
||||
|
||||
def parse_ip(addrtype, data, length, offset):
|
||||
if addrtype == QTYPE_A:
|
||||
return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
|
||||
elif addrtype == QTYPE_AAAA:
|
||||
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
|
||||
elif addrtype == QTYPE_CNAME:
|
||||
return parse_name(data, offset, length)[1]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def parse_name(data, offset, length=512):
|
||||
p = offset
|
||||
if (ord(data[offset]) & (128 + 64)) == (128 + 64):
|
||||
|
@ -110,7 +112,7 @@ def parse_name(data, offset, length=512):
|
|||
pointer = pointer & 0x3FFF
|
||||
if pointer == offset:
|
||||
return (0, None)
|
||||
return (2, parse_name(data, pointer))
|
||||
return (2, parse_name(data, pointer)[1])
|
||||
else:
|
||||
labels = []
|
||||
l = ord(data[p])
|
||||
|
@ -173,7 +175,7 @@ def parse_record(data, offset, question=False):
|
|||
return len + 4, (name, None, record_type, record_class, None, None)
|
||||
|
||||
|
||||
def unpack_response(data):
|
||||
def parse_response(data):
|
||||
try:
|
||||
if len(data) >= 12:
|
||||
header = struct.unpack('!HBBHHHH', data[:12])
|
||||
|
@ -214,39 +216,185 @@ def unpack_response(data):
|
|||
offset += l
|
||||
if r:
|
||||
ars.append(r)
|
||||
|
||||
return ans
|
||||
|
||||
response = DNSResponse()
|
||||
if qds:
|
||||
response.hostname = qds[0][0]
|
||||
for an in ans:
|
||||
response.answers.append((an[1], an[2], an[3]))
|
||||
return response
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def resolve(address, callback):
|
||||
# TODO async
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP)
|
||||
req = pack_request(address)
|
||||
if req is None:
|
||||
# TODO
|
||||
return
|
||||
sock.sendto(req, ('8.8.8.8', 53))
|
||||
res, addr = sock.recvfrom(1024)
|
||||
parsed_res = unpack_response(res)
|
||||
callback(parsed_res)
|
||||
def is_ip(address):
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
socket.inet_pton(family, address)
|
||||
return True
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
class DNSResponse(object):
|
||||
def __init__(self):
|
||||
self.hostname = None
|
||||
self.answers = [] # each: (addr, type, class)
|
||||
|
||||
def __str__(self):
|
||||
return '%s: %s' % (self.hostname, str(self.answers))
|
||||
|
||||
|
||||
STATUS_IPV4 = 0
|
||||
STATUS_IPV6 = 1
|
||||
|
||||
|
||||
class DNSResolver(object):
|
||||
|
||||
def __init__(self):
|
||||
self._loop = None
|
||||
self._hostname_status = {}
|
||||
self._hostname_to_cb = {}
|
||||
self._cb_to_hostname = {}
|
||||
# TODO add caching
|
||||
# TODO try ipv4 and ipv6 sequencely
|
||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
||||
socket.SOL_UDP)
|
||||
self._sock.setblocking(False)
|
||||
self._parse_config()
|
||||
|
||||
def _parse_config(self):
|
||||
try:
|
||||
with open('/etc/resolv.conf', 'rb') as f:
|
||||
servers = []
|
||||
content = f.readlines()
|
||||
for line in content:
|
||||
line = line.strip()
|
||||
if line:
|
||||
if line.startswith('nameserver'):
|
||||
parts = line.split(' ')
|
||||
if len(parts) >= 2:
|
||||
server = parts[1]
|
||||
if is_ip(server):
|
||||
servers.append(server)
|
||||
# TODO support more servers
|
||||
if servers:
|
||||
self._dns_server = (servers[0], 53)
|
||||
return
|
||||
except IOError:
|
||||
pass
|
||||
self._dns_server = ('8.8.8.8', 53)
|
||||
|
||||
def add_to_loop(self, loop):
|
||||
self._loop = loop
|
||||
loop.add(self._sock, eventloop.POLL_IN)
|
||||
loop.add_handler(self.handle_events)
|
||||
|
||||
def _handle_data(self, data):
|
||||
response = parse_response(data)
|
||||
if response and response.hostname:
|
||||
hostname = response.hostname
|
||||
callbacks = self._hostname_to_cb.get(hostname, [])
|
||||
ip = None
|
||||
for answer in response.answers:
|
||||
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
|
||||
answer[2] == QCLASS_IN:
|
||||
ip = answer[0]
|
||||
break
|
||||
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
|
||||
== STATUS_IPV4:
|
||||
self._hostname_status[hostname] = STATUS_IPV6
|
||||
self._send_req(hostname, QTYPE_AAAA)
|
||||
return
|
||||
for callback in callbacks:
|
||||
if self._cb_to_hostname.__contains__(callback):
|
||||
del self._cb_to_hostname[callback]
|
||||
callback((hostname, ip), None)
|
||||
if self._hostname_to_cb.__contains__(hostname):
|
||||
del self._hostname_to_cb[hostname]
|
||||
|
||||
def handle_events(self, events):
|
||||
for sock, fd, event in events:
|
||||
if sock != self._sock:
|
||||
continue
|
||||
if event & eventloop.POLL_ERR:
|
||||
logging.error('dns socket err')
|
||||
self._loop.remove(self._sock)
|
||||
self._sock.close()
|
||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
||||
socket.SOL_UDP)
|
||||
self._sock.setblocking(False)
|
||||
self._loop.add(self._sock, eventloop.POLL_IN)
|
||||
else:
|
||||
data, addr = sock.recvfrom(1024)
|
||||
if addr != self._dns_server:
|
||||
logging.warn('received a packet other than our dns')
|
||||
break
|
||||
self._handle_data(data)
|
||||
break
|
||||
|
||||
def remove_callback(self, callback):
|
||||
hostname = self._cb_to_hostname.get(callback)
|
||||
if hostname:
|
||||
del self._cb_to_hostname[callback]
|
||||
arr = self._hostname_to_cb.get(hostname, None)
|
||||
if arr:
|
||||
arr.remove(callback)
|
||||
if not arr:
|
||||
del self._hostname_to_cb[hostname]
|
||||
|
||||
def _send_req(self, hostname, qtype):
|
||||
logging.debug('resolving %s with type %d using server %s', hostname,
|
||||
qtype, self._dns_server)
|
||||
req = build_request(hostname, qtype)
|
||||
self._sock.sendto(req, self._dns_server)
|
||||
|
||||
def resolve(self, hostname, callback):
|
||||
if not hostname:
|
||||
callback(None, Exception('empty hostname'))
|
||||
elif is_ip(hostname):
|
||||
callback(hostname, None)
|
||||
else:
|
||||
arr = self._hostname_to_cb.get(hostname, None)
|
||||
if not arr:
|
||||
self._hostname_status[hostname] = STATUS_IPV4
|
||||
self._send_req(hostname, QTYPE_A)
|
||||
self._hostname_to_cb[hostname] = [callback]
|
||||
self._cb_to_hostname[callback] = hostname
|
||||
else:
|
||||
arr.append(callback)
|
||||
|
||||
|
||||
def test():
|
||||
def _callback(address):
|
||||
print address
|
||||
logging.getLogger('').handlers = []
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)-8s %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S', filemode='a+')
|
||||
|
||||
resolve('www.twitter.com', _callback)
|
||||
resolve('www.google.com', _callback)
|
||||
resolve('ipv6.google.com', _callback)
|
||||
resolve('ipv6.l.google.com', _callback)
|
||||
resolve('www.baidu.com', _callback)
|
||||
resolve('www.a.shifen.com', _callback)
|
||||
resolve('m.baidu.jp', _callback)
|
||||
def _callback(address, error):
|
||||
print error, address
|
||||
|
||||
loop = eventloop.EventLoop()
|
||||
resolver = DNSResolver()
|
||||
resolver.add_to_loop(loop)
|
||||
|
||||
resolver.resolve('8.8.8.8', _callback)
|
||||
resolver.resolve('www.twitter.com', _callback)
|
||||
resolver.resolve('www.google.com', _callback)
|
||||
resolver.resolve('ipv6.google.com', _callback)
|
||||
resolver.resolve('ipv6.l.google.com', _callback)
|
||||
resolver.resolve('www.gmail.com', _callback)
|
||||
resolver.resolve('r4---sn-3qqp-ioql.googlevideo.com', _callback)
|
||||
resolver.resolve('www.baidu.com', _callback)
|
||||
resolver.resolve('www.a.shifen.com', _callback)
|
||||
resolver.resolve('m.baidu.jp', _callback)
|
||||
resolver.resolve('www.youku.com', _callback)
|
||||
resolver.resolve('www.twitter.com', _callback)
|
||||
resolver.resolve('ipv6.google.com', _callback)
|
||||
|
||||
loop.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Add table
Reference in a new issue