more work

This commit is contained in:
clowwindy 2014-06-08 14:17:26 +08:00
parent 6b76319495
commit 45f9998fa9
2 changed files with 39 additions and 51 deletions

View file

@ -28,6 +28,8 @@ import common
import eventloop
common.patch_socket()
_request_count = 1
# rfc1035
@ -99,40 +101,30 @@ def parse_ip(addrtype, data, length, offset):
elif addrtype == QTYPE_AAAA:
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
elif addrtype == QTYPE_CNAME:
return parse_name(data, offset, length)[1]
return parse_name(data, offset)[1]
else:
return data
def parse_name(data, offset, length=512):
def parse_name(data, offset):
p = offset
if (ord(data[offset]) & (128 + 64)) == (128 + 64):
# pointer
pointer = struct.unpack('!H', data[offset:offset + 2])[0]
pointer = pointer & 0x3FFF
if pointer == offset:
return (0, None)
return (2, parse_name(data, pointer)[1])
else:
labels = []
l = ord(data[p])
while l > 0 and p < offset + length:
while l > 0:
if (l & (128 + 64)) == (128 + 64):
# pointer
pointer = struct.unpack('!H', data[p:p + 2])[0]
pointer = pointer & 0x3FFF
# if pointer == offset:
# return (0, None)
pointer &= 0x3FFF
r = parse_name(data, pointer)
labels.append(r[1])
p += 2
# pointer is the end
return (p - offset + 1, '.'.join(labels))
return p - offset, '.'.join(labels)
else:
labels.append(data[p + 1:p + 1 + l])
p += 1 + l
l = ord(data[p])
return (p - offset + 1, '.'.join(labels))
return p - offset + 1, '.'.join(labels)
# rfc1035
@ -158,33 +150,30 @@ def parse_name(data, offset, length=512):
# / /
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False):
len, name = parse_name(data, offset)
# TODO
assert len
nlen, name = parse_name(data, offset)
if not question:
record_type, record_class, record_ttl, record_rdlength = struct.unpack(
'!HHiH', data[offset + len:offset + len + 10]
'!HHiH', data[offset + nlen:offset + nlen + 10]
)
ip = parse_ip(record_type, data, record_rdlength, offset + len + 10)
return len + 10 + record_rdlength, \
ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10)
return nlen + 10 + record_rdlength, \
(name, ip, record_type, record_class, record_ttl)
else:
record_type, record_class = struct.unpack(
'!HH', data[offset + len:offset + len + 4]
'!HH', data[offset + nlen:offset + nlen + 4]
)
return len + 4, (name, None, record_type, record_class, None, None)
return nlen + 4, (name, None, record_type, record_class, None, None)
def parse_response(data):
try:
if len(data) >= 12:
header = struct.unpack('!HBBHHHH', data[:12])
res_id = header[0]
res_qr = header[1] & 128
# res_id = header[0]
# res_qr = header[1] & 128
res_tc = header[1] & 2
res_ra = header[2] & 128
# res_ra = header[2] & 128
res_rcode = header[2] & 15
# TODO check tc and rcode
assert res_tc == 0
assert res_rcode == 0
res_qdcount = header[3]
@ -193,8 +182,6 @@ def parse_response(data):
res_arcount = header[6]
qds = []
ans = []
nss = []
ars = []
offset = 12
for i in xrange(0, res_qdcount):
l, r = parse_record(data, offset, True)
@ -209,13 +196,9 @@ def parse_response(data):
for i in xrange(0, res_nscount):
l, r = parse_record(data, offset)
offset += l
if r:
nss.append(r)
for i in xrange(0, res_arcount):
l, r = parse_record(data, offset)
offset += l
if r:
ars.append(r)
response = DNSResponse()
if qds:
response.hostname = qds[0][0]
@ -225,6 +208,7 @@ def parse_response(data):
except Exception as e:
import traceback
traceback.print_exc()
logging.error(e)
return None
@ -380,9 +364,9 @@ def test():
resolver = DNSResolver()
resolver.add_to_loop(loop)
resolver.resolve('www.google.com', _callback)
resolver.resolve('8.8.8.8', _callback)
resolver.resolve('www.twitter.com', _callback)
resolver.resolve('www.google.com', _callback)
resolver.resolve('ipv6.google.com', _callback)
resolver.resolve('ipv6.l.google.com', _callback)
resolver.resolve('www.gmail.com', _callback)

View file

@ -63,13 +63,17 @@ def inet_pton(family, addr):
raise RuntimeError("What family?")
if not hasattr(socket, 'inet_pton'):
def patch_socket():
if not hasattr(socket, 'inet_pton'):
socket.inet_pton = inet_pton
if not hasattr(socket, 'inet_ntop'):
if not hasattr(socket, 'inet_ntop'):
socket.inet_ntop = inet_ntop
patch_socket()
ADDRTYPE_IPV4 = 1
ADDRTYPE_IPV6 = 4
ADDRTYPE_HOST = 3