more work

This commit is contained in:
clowwindy 2014-06-08 14:17:26 +08:00
parent 6b76319495
commit 45f9998fa9
2 changed files with 39 additions and 51 deletions

View file

@ -28,6 +28,8 @@ import common
import eventloop import eventloop
common.patch_socket()
_request_count = 1 _request_count = 1
# rfc1035 # rfc1035
@ -99,40 +101,30 @@ def parse_ip(addrtype, data, length, offset):
elif addrtype == QTYPE_AAAA: elif addrtype == QTYPE_AAAA:
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length]) return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
elif addrtype == QTYPE_CNAME: elif addrtype == QTYPE_CNAME:
return parse_name(data, offset, length)[1] return parse_name(data, offset)[1]
else: else:
return data return data
def parse_name(data, offset, length=512): def parse_name(data, offset):
p = offset p = offset
if (ord(data[offset]) & (128 + 64)) == (128 + 64):
# pointer
pointer = struct.unpack('!H', data[offset:offset + 2])[0]
pointer = pointer & 0x3FFF
if pointer == offset:
return (0, None)
return (2, parse_name(data, pointer)[1])
else:
labels = [] labels = []
l = ord(data[p]) l = ord(data[p])
while l > 0 and p < offset + length: while l > 0:
if (l & (128 + 64)) == (128 + 64): if (l & (128 + 64)) == (128 + 64):
# pointer # pointer
pointer = struct.unpack('!H', data[p:p + 2])[0] pointer = struct.unpack('!H', data[p:p + 2])[0]
pointer = pointer & 0x3FFF pointer &= 0x3FFF
# if pointer == offset:
# return (0, None)
r = parse_name(data, pointer) r = parse_name(data, pointer)
labels.append(r[1]) labels.append(r[1])
p += 2 p += 2
# pointer is the end # pointer is the end
return (p - offset + 1, '.'.join(labels)) return p - offset, '.'.join(labels)
else: else:
labels.append(data[p + 1:p + 1 + l]) labels.append(data[p + 1:p + 1 + l])
p += 1 + l p += 1 + l
l = ord(data[p]) l = ord(data[p])
return (p - offset + 1, '.'.join(labels)) return p - offset + 1, '.'.join(labels)
# rfc1035 # rfc1035
@ -158,33 +150,30 @@ def parse_name(data, offset, length=512):
# / / # / /
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ # +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False): def parse_record(data, offset, question=False):
len, name = parse_name(data, offset) nlen, name = parse_name(data, offset)
# TODO
assert len
if not question: if not question:
record_type, record_class, record_ttl, record_rdlength = struct.unpack( record_type, record_class, record_ttl, record_rdlength = struct.unpack(
'!HHiH', data[offset + len:offset + len + 10] '!HHiH', data[offset + nlen:offset + nlen + 10]
) )
ip = parse_ip(record_type, data, record_rdlength, offset + len + 10) ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10)
return len + 10 + record_rdlength, \ return nlen + 10 + record_rdlength, \
(name, ip, record_type, record_class, record_ttl) (name, ip, record_type, record_class, record_ttl)
else: else:
record_type, record_class = struct.unpack( record_type, record_class = struct.unpack(
'!HH', data[offset + len:offset + len + 4] '!HH', data[offset + nlen:offset + nlen + 4]
) )
return len + 4, (name, None, record_type, record_class, None, None) return nlen + 4, (name, None, record_type, record_class, None, None)
def parse_response(data): def parse_response(data):
try: try:
if len(data) >= 12: if len(data) >= 12:
header = struct.unpack('!HBBHHHH', data[:12]) header = struct.unpack('!HBBHHHH', data[:12])
res_id = header[0] # res_id = header[0]
res_qr = header[1] & 128 # res_qr = header[1] & 128
res_tc = header[1] & 2 res_tc = header[1] & 2
res_ra = header[2] & 128 # res_ra = header[2] & 128
res_rcode = header[2] & 15 res_rcode = header[2] & 15
# TODO check tc and rcode
assert res_tc == 0 assert res_tc == 0
assert res_rcode == 0 assert res_rcode == 0
res_qdcount = header[3] res_qdcount = header[3]
@ -193,8 +182,6 @@ def parse_response(data):
res_arcount = header[6] res_arcount = header[6]
qds = [] qds = []
ans = [] ans = []
nss = []
ars = []
offset = 12 offset = 12
for i in xrange(0, res_qdcount): for i in xrange(0, res_qdcount):
l, r = parse_record(data, offset, True) l, r = parse_record(data, offset, True)
@ -209,13 +196,9 @@ def parse_response(data):
for i in xrange(0, res_nscount): for i in xrange(0, res_nscount):
l, r = parse_record(data, offset) l, r = parse_record(data, offset)
offset += l offset += l
if r:
nss.append(r)
for i in xrange(0, res_arcount): for i in xrange(0, res_arcount):
l, r = parse_record(data, offset) l, r = parse_record(data, offset)
offset += l offset += l
if r:
ars.append(r)
response = DNSResponse() response = DNSResponse()
if qds: if qds:
response.hostname = qds[0][0] response.hostname = qds[0][0]
@ -225,6 +208,7 @@ def parse_response(data):
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
logging.error(e)
return None return None
@ -380,9 +364,9 @@ def test():
resolver = DNSResolver() resolver = DNSResolver()
resolver.add_to_loop(loop) resolver.add_to_loop(loop)
resolver.resolve('www.google.com', _callback)
resolver.resolve('8.8.8.8', _callback) resolver.resolve('8.8.8.8', _callback)
resolver.resolve('www.twitter.com', _callback) resolver.resolve('www.twitter.com', _callback)
resolver.resolve('www.google.com', _callback)
resolver.resolve('ipv6.google.com', _callback) resolver.resolve('ipv6.google.com', _callback)
resolver.resolve('ipv6.l.google.com', _callback) resolver.resolve('ipv6.l.google.com', _callback)
resolver.resolve('www.gmail.com', _callback) resolver.resolve('www.gmail.com', _callback)

View file

@ -63,13 +63,17 @@ def inet_pton(family, addr):
raise RuntimeError("What family?") raise RuntimeError("What family?")
if not hasattr(socket, 'inet_pton'): def patch_socket():
if not hasattr(socket, 'inet_pton'):
socket.inet_pton = inet_pton socket.inet_pton = inet_pton
if not hasattr(socket, 'inet_ntop'): if not hasattr(socket, 'inet_ntop'):
socket.inet_ntop = inet_ntop socket.inet_ntop = inet_ntop
patch_socket()
ADDRTYPE_IPV4 = 1 ADDRTYPE_IPV4 = 1
ADDRTYPE_IPV6 = 4 ADDRTYPE_IPV6 = 4
ADDRTYPE_HOST = 3 ADDRTYPE_HOST = 3