diff --git a/storage/test/test_cloudfront.py b/storage/test/test_cloudfront.py index 99f6ca058..500e8cfbf 100644 --- a/storage/test/test_cloudfront.py +++ b/storage/test/test_cloudfront.py @@ -8,7 +8,7 @@ import boto from app import config_provider from storage import CloudFrontedS3Storage, StorageContext from util.ipresolver import IPResolver -from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data +from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache from test.fixtures import * _TEST_CONTENT = os.urandom(1024) @@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path' def ipranges_populated(request): return request.param +@pytest.fixture() +def test_empty_ip_range_cache(empty_range_data): + sync_token = empty_range_data['syncToken'] + all_amazon, regions = IPResolver._parse_amazon_ranges(empty_range_data) + fake_cache = { + 'sync_token': sync_token, + 'all_amazon': all_amazon, + 'regions': regions, + } + return fake_cache + +@pytest.fixture() +def empty_range_data(): + empty_range_data = { + 'syncToken': 123456789, + 'prefixes': [], + } + return empty_range_data + @mock_s3 -def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app): +def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app): ipresolver = IPResolver(app) if ipranges_populated: - empty_range_data = { - 'syncToken': 123456789, - 'prefixes': [], - } - with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data): + with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache): context = StorageContext('nyc', None, None, config_provider, ipresolver) # Create a test bucket and put some test content. diff --git a/util/ipresolver/__init__.py b/util/ipresolver/__init__.py index 023f69e22..d57bacd18 100644 --- a/util/ipresolver/__init__.py +++ b/util/ipresolver/__init__.py @@ -1,8 +1,10 @@ import logging import json +import time from collections import namedtuple, defaultdict +from threading import Thread, Lock from abc import ABCMeta, abstractmethod from six import add_metaclass from cachetools import ttl_cache, lru_cache @@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic logger = logging.getLogger(__name__) _DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'} +_UPDATE_INTERVAL = 600 +_FAILED_UPDATE_RETRY_SECS = 60 + +CACHE = {} +CACHE_LOCK = Lock() def update_resolver_datafiles(): """ Performs an update of the data file(s) used by the IP Resolver. """ @@ -33,6 +40,20 @@ def update_resolver_datafiles(): f.write(response.text) logger.debug('Successfully wrote %s', filename) +def _get_aws_ip_ranges(): + try: + with open('util/ipresolver/aws-ip-ranges.json', 'r') as f: + return json.loads(f.read()) + except IOError: + logger.exception('Could not load AWS IP Ranges') + return None + except ValueError: + logger.exception('Could not load AWS IP Ranges') + return None + except TypeError: + logger.exception('Could not load AWS IP Ranges') + return None + @add_metaclass(ABCMeta) class IPResolverInterface(object): @@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface): def __init__(self, app): self.app = app self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb') + self._worker = _UpdateIPRange(_UPDATE_INTERVAL) + if not app.config.get('TESTING', False): + self._worker.start() @ttl_cache(maxsize=100, ttl=600) def is_ip_possible_threat(self, ip_address): @@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface): return location_function(ip_address) - def _get_aws_ip_ranges(self): - try: - with open('util/ipresolver/aws-ip-ranges.json', 'r') as f: - return json.loads(f.read()) - except IOError: - logger.exception('Could not load AWS IP Ranges') - return None - except ValueError: - logger.exception('Could not load AWS IP Ranges') - return None - except TypeError: - logger.exception('Could not load AWS IP Ranges') - return None - - @ttl_cache(maxsize=1, ttl=600) def _get_location_function(self): - aws_ip_range_json = self._get_aws_ip_ranges() - if aws_ip_range_json is None: + try: + cache = CACHE + sync_token = cache.get('sync_token', None) + if sync_token is None: + logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json']) + return None + + all_amazon = cache['all_amazon'] + regions = cache['regions'] + except KeyError: + logger.exception('Got exception trying to hit aws ip range cache') + return None + except Exception: + logger.exception('Got exception trying to hit aws ip range cache') return None - sync_token = aws_ip_range_json['syncToken'] - all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json) return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db) @staticmethod @@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface): regions[region].add(cidr) return all_amazon, regions + + +class _UpdateIPRange(Thread): + """Helper class that uses a thread to loads the IP ranges from Amazon""" + def __init__(self, interval): + Thread.__init__(self) + self.interval = interval + + def run(self): + while True: + try: + logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json') + aws_ip_range_json = _get_aws_ip_ranges() + except: + logger.exception('Failed trying to update aws ip range') + time.sleep(_FAILED_UPDATE_RETRY_SECS) + break + + sync_token = aws_ip_range_json['syncToken'] + all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json) + + with CACHE_LOCK: + CACHE['sync_token'] = sync_token + CACHE['all_amazon'] = all_amazon + CACHE['regions'] = regions + + time.sleep(self.interval) diff --git a/util/ipresolver/test/test_ipresolver.py b/util/ipresolver/test/test_ipresolver.py index d5b604a8c..de9361e8e 100644 --- a/util/ipresolver/test/test_ipresolver.py +++ b/util/ipresolver/test/test_ipresolver.py @@ -2,7 +2,7 @@ import pytest from mock import patch -from util.ipresolver import IPResolver, ResolvedLocation +from util.ipresolver import IPResolver, ResolvedLocation, CACHE from test.fixtures import * @pytest.fixture() @@ -23,23 +23,35 @@ def aws_ip_range_data(): } return fake_range_doc -def test_unstarted(app, test_aws_ip): - ipresolver = IPResolver(app) +@pytest.fixture() +def test_ip_range_cache(aws_ip_range_data): + sync_token = aws_ip_range_data['syncToken'] + all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_data) + fake_cache = { + 'sync_token': sync_token, + 'all_amazon': all_amazon, + 'regions': regions, + } + return fake_cache - def get_data(): - return None +@pytest.fixture() +def unstarted_cache(): + fake_unstarted_cache = {} + return fake_unstarted_cache - with patch.object(ipresolver, '_get_aws_ip_ranges', get_data): - assert ipresolver.resolve_ip(test_aws_ip) is None - -def test_resolved(aws_ip_range_data, test_aws_ip, app): +def test_unstarted(app, test_aws_ip, unstarted_cache): + with patch('util.ipresolver._UpdateIPRange'): ipresolver = IPResolver(app) - def get_data(): - return aws_ip_range_data + with patch.dict('util.ipresolver.CACHE', unstarted_cache): + assert ipresolver.resolve_ip(test_aws_ip) is None - with patch.object(ipresolver, '_get_aws_ip_ranges', get_data): - assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) - assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) - assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789) - assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789) +def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app): + with patch('util.ipresolver._UpdateIPRange'): + ipresolver = IPResolver(app) + + with patch.dict('util.ipresolver.CACHE', test_ip_range_cache): + assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) + assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) + assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789) + assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)