From 100ebcf064f4a8e5840012bdccef106794208569 Mon Sep 17 00:00:00 2001 From: Sunny Date: Sat, 31 Jan 2015 19:50:10 +0800 Subject: [PATCH] Add IPNetwork class to support CIDR calculation Usage: Use IPNetwork(str|list) to create an IPNetwork object. Use operator 'in' to determine whether the specified IP address is in the IP network or not, like: >>> '192.168.1.1' in IPNetwork('192.168.1.0/24') True Both IPv4 and IPv6 address are supported. Note: When using string to initialize the IPNetwork, a comma seperated IP network list should be provided. Currently, IPNetwork just support standard CIDR like: x.x.x.x/y eg. 192.168.1.0/24 ::x/y eg. ::1/10 If pure IP address was provided, it will be treated as implicit IP network, like 192.168.0.0 will be treated as 192.168.0.0/16 and 192.168.1.1 will be treated as 192.168.1.1/32 This implicit translate may cause some unexpected behavior, like user provide 192.168.2.0 and expect it will be treated as 192.168.2.0/24 but actually it will be translated to 192.168.2.0/23 because there are 9 continuous 0 from right. In order to avoid confusion, a warning message will be displayed when pure IP address was provided. Other variants of CIDR are not supported yet. --- shadowsocks/common.py | 67 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/shadowsocks/common.py b/shadowsocks/common.py index 2be17e5..d582923 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -184,6 +184,57 @@ def parse_header(data): return addrtype, to_bytes(dest_addr), dest_port, header_length +class IPNetwork(object): + ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128} + + def __init__(self, addrs): + self._network_list_v4 = [] + self._network_list_v6 = [] + if type(addrs) == str: + addrs = addrs.split(',') + map(self.add_network, addrs) + + def add_network(self, addr): + block = addr.split('/') + addr_family = is_ip(block[0]) + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(block[0])) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, block[0])) + ip = (hi << 64) | lo + else: + raise SyntaxError("Not a valid CIDR notation: %s" % addr) + if len(block) is 1: + prefix_size = 0 + while ((ip & 1) == 0): + ip >>= 1 + prefix_size += 1 + logging.warn("You did't specify CIDR routing prefix size for %s, " + "implicit treated as %s/%d" % (addr, addr, + IPNetwork.ADDRLENGTH[addr_family] - prefix_size)) + elif block[1].isdigit() and int(block[1]) <= IPNetwork.ADDRLENGTH[addr_family]: + prefix_size = IPNetwork.ADDRLENGTH[addr_family] - int(block[1]) + ip >>= prefix_size + else: + raise SyntaxError("Not a valid CIDR notation: %s" % addr) + if addr_family is socket.AF_INET: + self._network_list_v4.append((ip, prefix_size)) + else: + self._network_list_v6.append((ip, prefix_size)) + + def __contains__(self, addr): + addr_family = is_ip(addr) + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(addr)) + return any(map(lambda (naddr, ps): naddr == ip >> ps, self._network_list_v4)) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr)) + ip = (hi << 64) | lo + return any(map(lambda (naddr, ps): naddr == ip >> ps, self._network_list_v6)) + else: + return False + + def test_inet_conv(): ipv4 = b'8.8.4.4' b = inet_pton(socket.AF_INET, ipv4) @@ -210,7 +261,23 @@ def test_pack_header(): assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com' +def test_ip_network(): + ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.168.2.0') + assert '127.0.0.1' in ip_network + assert '127.0.1.1' not in ip_network + assert ':ff:ffff' in ip_network + assert '::ffff:1' not in ip_network + assert '::1' in ip_network + assert '::2' not in ip_network + assert '192.168.1.1' in ip_network + assert '192.168.1.2' not in ip_network + assert '192.168.2.1' in ip_network + assert '192.168.3.1' in ip_network # 192.168.2.0 is treated as 192.168.2.0/23 + assert 'www.google.com' not in ip_network + + if __name__ == '__main__': test_inet_conv() test_parse_header() test_pack_header() + test_ip_network()