add tests for common.py
This commit is contained in:
parent
b5010df575
commit
70dae91e7c
2 changed files with 60 additions and 14 deletions
|
@ -48,31 +48,44 @@ chr = compat_chr
|
||||||
|
|
||||||
|
|
||||||
def to_bytes(s):
|
def to_bytes(s):
|
||||||
return s.encode('utf-8')
|
if bytes != str:
|
||||||
|
if type(s) == str:
|
||||||
|
return s.encode('utf-8')
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def to_str(s):
|
||||||
|
if bytes != str:
|
||||||
|
if type(s) == bytes:
|
||||||
|
return s.decode('utf-8')
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
def inet_ntop(family, ipstr):
|
def inet_ntop(family, ipstr):
|
||||||
if family == socket.AF_INET:
|
if family == socket.AF_INET:
|
||||||
return socket.inet_ntoa(ipstr)
|
return to_bytes(socket.inet_ntoa(ipstr))
|
||||||
elif family == socket.AF_INET6:
|
elif family == socket.AF_INET6:
|
||||||
v6addr = b':'.join((b'%02X%02X' % (ord(i), ord(j)))
|
import re
|
||||||
for i, j in zip(ipstr[::2], ipstr[1::2]))
|
v6addr = ':'.join(('%02X%02X' % (ord(i), ord(j))).lstrip('0')
|
||||||
return v6addr
|
for i, j in zip(ipstr[::2], ipstr[1::2]))
|
||||||
|
v6addr = re.sub('::+', '::', v6addr, count=1)
|
||||||
|
return to_bytes(v6addr)
|
||||||
|
|
||||||
|
|
||||||
def inet_pton(family, addr):
|
def inet_pton(family, addr):
|
||||||
|
addr = to_str(addr)
|
||||||
if family == socket.AF_INET:
|
if family == socket.AF_INET:
|
||||||
return socket.inet_aton(addr)
|
return socket.inet_aton(addr)
|
||||||
elif family == socket.AF_INET6:
|
elif family == socket.AF_INET6:
|
||||||
if b'.' in addr: # a v4 addr
|
if '.' in addr: # a v4 addr
|
||||||
v4addr = addr[addr.rindex(b':') + 1:]
|
v4addr = addr[addr.rindex(':') + 1:]
|
||||||
v4addr = socket.inet_aton(v4addr)
|
v4addr = socket.inet_aton(v4addr)
|
||||||
v4addr = map(lambda x: (b'%02X' % ord(x)), v4addr)
|
v4addr = map(lambda x: ('%02X' % ord(x)), v4addr)
|
||||||
v4addr.insert(2, b':')
|
v4addr.insert(2, ':')
|
||||||
newaddr = addr[:addr.rindex(b':') + 1] + b''.join(v4addr)
|
newaddr = addr[:addr.rindex(':') + 1] + ''.join(v4addr)
|
||||||
return inet_pton(family, newaddr)
|
return inet_pton(family, newaddr)
|
||||||
dbyts = [0] * 8 # 8 groups
|
dbyts = [0] * 8 # 8 groups
|
||||||
grps = addr.split(b':')
|
grps = addr.split(':')
|
||||||
for i, v in enumerate(grps):
|
for i, v in enumerate(grps):
|
||||||
if v:
|
if v:
|
||||||
dbyts[i] = int(v, 16)
|
dbyts[i] = int(v, 16)
|
||||||
|
@ -105,9 +118,10 @@ ADDRTYPE_HOST = 3
|
||||||
|
|
||||||
|
|
||||||
def pack_addr(address):
|
def pack_addr(address):
|
||||||
|
address_str = to_str(address)
|
||||||
for family in (socket.AF_INET, socket.AF_INET6):
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
try:
|
try:
|
||||||
r = socket.inet_pton(family, address)
|
r = socket.inet_pton(family, address_str)
|
||||||
if family == socket.AF_INET6:
|
if family == socket.AF_INET6:
|
||||||
return b'\x04' + r
|
return b'\x04' + r
|
||||||
else:
|
else:
|
||||||
|
@ -155,4 +169,36 @@ def parse_header(data):
|
||||||
addrtype)
|
addrtype)
|
||||||
if dest_addr is None:
|
if dest_addr is None:
|
||||||
return None
|
return None
|
||||||
return addrtype, dest_addr, dest_port, header_length
|
return addrtype, to_bytes(dest_addr), dest_port, header_length
|
||||||
|
|
||||||
|
|
||||||
|
def test_inet_conv():
|
||||||
|
ipv4 = b'8.8.4.4'
|
||||||
|
b = inet_pton(socket.AF_INET, ipv4)
|
||||||
|
assert inet_ntop(socket.AF_INET, b) == ipv4
|
||||||
|
ipv6 = b'2404:6800:4005:805::1011'
|
||||||
|
b = inet_pton(socket.AF_INET6, ipv6)
|
||||||
|
assert inet_ntop(socket.AF_INET6, b) == ipv6
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_header():
|
||||||
|
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
|
||||||
|
(3, b'www.google.com', 80, 18)
|
||||||
|
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
|
||||||
|
(1, b'8.8.8.8', 53, 7)
|
||||||
|
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
|
||||||
|
b'\x00\x10\x11\x00\x50')) == \
|
||||||
|
(4, b'2404:6800:4005:805::1011', 80, 19)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pack_header():
|
||||||
|
assert pack_addr(b'8.8.8.8') == b'\x01\x08\x08\x08\x08'
|
||||||
|
assert pack_addr(b'2404:6800:4005:805::1011') == \
|
||||||
|
b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00\x00\x10\x11'
|
||||||
|
assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_inet_conv()
|
||||||
|
test_parse_header()
|
||||||
|
test_pack_header()
|
||||||
|
|
|
@ -261,7 +261,7 @@ class TCPRelayHandler(object):
|
||||||
if header_result is None:
|
if header_result is None:
|
||||||
raise Exception('can not parse header')
|
raise Exception('can not parse header')
|
||||||
addrtype, remote_addr, remote_port, header_length = header_result
|
addrtype, remote_addr, remote_port, header_length = header_result
|
||||||
logging.info('connecting %s:%d' % (remote_addr.decode('utf-8'),
|
logging.info('connecting %s:%d' % (common.to_str(remote_addr),
|
||||||
remote_port))
|
remote_port))
|
||||||
self._remote_address = (remote_addr, remote_port)
|
self._remote_address = (remote_addr, remote_port)
|
||||||
# pause reading
|
# pause reading
|
||||||
|
|
Loading…
Reference in a new issue