From b6336393de4ae91efb4a9f6ab9e6f5044003ca7f Mon Sep 17 00:00:00 2001 From: Kenny Lee Sin Cheong Date: Thu, 23 Aug 2018 11:49:51 -0400 Subject: [PATCH] Make IPResolver run the update in a separate thread A separate thread will cache the results of parsing the range file, and the IPResolver will hit the cache instead of blocking while recomputing the ranges everytime. The thread updates every 600s, and retry every 60s on failures. --- storage/test/test_cloudfront.py | 29 ++++++--- util/ipresolver/__init__.py | 84 +++++++++++++++++++------ util/ipresolver/test/test_ipresolver.py | 44 ++++++++----- 3 files changed, 115 insertions(+), 42 deletions(-) diff --git a/storage/test/test_cloudfront.py b/storage/test/test_cloudfront.py index 99f6ca058..500e8cfbf 100644 --- a/storage/test/test_cloudfront.py +++ b/storage/test/test_cloudfront.py @@ -8,7 +8,7 @@ import boto from app import config_provider from storage import CloudFrontedS3Storage, StorageContext from util.ipresolver import IPResolver -from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data +from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache from test.fixtures import * _TEST_CONTENT = os.urandom(1024) @@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path' def ipranges_populated(request): return request.param +@pytest.fixture() +def test_empty_ip_range_cache(empty_range_data): + sync_token = empty_range_data['syncToken'] + all_amazon, regions = IPResolver._parse_amazon_ranges(empty_range_data) + fake_cache = { + 'sync_token': sync_token, + 'all_amazon': all_amazon, + 'regions': regions, + } + return fake_cache + +@pytest.fixture() +def empty_range_data(): + empty_range_data = { + 'syncToken': 123456789, + 'prefixes': [], + } + return empty_range_data + @mock_s3 -def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app): +def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app): ipresolver = IPResolver(app) if ipranges_populated: - empty_range_data = { - 'syncToken': 123456789, - 'prefixes': [], - } - with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data): + with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache): context = StorageContext('nyc', None, None, config_provider, ipresolver) # Create a test bucket and put some test content. diff --git a/util/ipresolver/__init__.py b/util/ipresolver/__init__.py index 023f69e22..d57bacd18 100644 --- a/util/ipresolver/__init__.py +++ b/util/ipresolver/__init__.py @@ -1,8 +1,10 @@ import logging import json +import time from collections import namedtuple, defaultdict +from threading import Thread, Lock from abc import ABCMeta, abstractmethod from six import add_metaclass from cachetools import ttl_cache, lru_cache @@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic logger = logging.getLogger(__name__) _DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'} +_UPDATE_INTERVAL = 600 +_FAILED_UPDATE_RETRY_SECS = 60 + +CACHE = {} +CACHE_LOCK = Lock() def update_resolver_datafiles(): """ Performs an update of the data file(s) used by the IP Resolver. """ @@ -33,6 +40,20 @@ def update_resolver_datafiles(): f.write(response.text) logger.debug('Successfully wrote %s', filename) +def _get_aws_ip_ranges(): + try: + with open('util/ipresolver/aws-ip-ranges.json', 'r') as f: + return json.loads(f.read()) + except IOError: + logger.exception('Could not load AWS IP Ranges') + return None + except ValueError: + logger.exception('Could not load AWS IP Ranges') + return None + except TypeError: + logger.exception('Could not load AWS IP Ranges') + return None + @add_metaclass(ABCMeta) class IPResolverInterface(object): @@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface): def __init__(self, app): self.app = app self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb') + self._worker = _UpdateIPRange(_UPDATE_INTERVAL) + if not app.config.get('TESTING', False): + self._worker.start() @ttl_cache(maxsize=100, ttl=600) def is_ip_possible_threat(self, ip_address): @@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface): return location_function(ip_address) - def _get_aws_ip_ranges(self): - try: - with open('util/ipresolver/aws-ip-ranges.json', 'r') as f: - return json.loads(f.read()) - except IOError: - logger.exception('Could not load AWS IP Ranges') - return None - except ValueError: - logger.exception('Could not load AWS IP Ranges') - return None - except TypeError: - logger.exception('Could not load AWS IP Ranges') - return None - - @ttl_cache(maxsize=1, ttl=600) def _get_location_function(self): - aws_ip_range_json = self._get_aws_ip_ranges() - if aws_ip_range_json is None: + try: + cache = CACHE + sync_token = cache.get('sync_token', None) + if sync_token is None: + logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json']) + return None + + all_amazon = cache['all_amazon'] + regions = cache['regions'] + except KeyError: + logger.exception('Got exception trying to hit aws ip range cache') + return None + except Exception: + logger.exception('Got exception trying to hit aws ip range cache') return None - sync_token = aws_ip_range_json['syncToken'] - all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json) return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db) @staticmethod @@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface): regions[region].add(cidr) return all_amazon, regions + + +class _UpdateIPRange(Thread): + """Helper class that uses a thread to loads the IP ranges from Amazon""" + def __init__(self, interval): + Thread.__init__(self) + self.interval = interval + + def run(self): + while True: + try: + logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json') + aws_ip_range_json = _get_aws_ip_ranges() + except: + logger.exception('Failed trying to update aws ip range') + time.sleep(_FAILED_UPDATE_RETRY_SECS) + break + + sync_token = aws_ip_range_json['syncToken'] + all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json) + + with CACHE_LOCK: + CACHE['sync_token'] = sync_token + CACHE['all_amazon'] = all_amazon + CACHE['regions'] = regions + + time.sleep(self.interval) diff --git a/util/ipresolver/test/test_ipresolver.py b/util/ipresolver/test/test_ipresolver.py index d5b604a8c..de9361e8e 100644 --- a/util/ipresolver/test/test_ipresolver.py +++ b/util/ipresolver/test/test_ipresolver.py @@ -2,7 +2,7 @@ import pytest from mock import patch -from util.ipresolver import IPResolver, ResolvedLocation +from util.ipresolver import IPResolver, ResolvedLocation, CACHE from test.fixtures import * @pytest.fixture() @@ -23,23 +23,35 @@ def aws_ip_range_data(): } return fake_range_doc -def test_unstarted(app, test_aws_ip): - ipresolver = IPResolver(app) +@pytest.fixture() +def test_ip_range_cache(aws_ip_range_data): + sync_token = aws_ip_range_data['syncToken'] + all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_data) + fake_cache = { + 'sync_token': sync_token, + 'all_amazon': all_amazon, + 'regions': regions, + } + return fake_cache - def get_data(): - return None +@pytest.fixture() +def unstarted_cache(): + fake_unstarted_cache = {} + return fake_unstarted_cache - with patch.object(ipresolver, '_get_aws_ip_ranges', get_data): - assert ipresolver.resolve_ip(test_aws_ip) is None - -def test_resolved(aws_ip_range_data, test_aws_ip, app): +def test_unstarted(app, test_aws_ip, unstarted_cache): + with patch('util.ipresolver._UpdateIPRange'): ipresolver = IPResolver(app) - def get_data(): - return aws_ip_range_data + with patch.dict('util.ipresolver.CACHE', unstarted_cache): + assert ipresolver.resolve_ip(test_aws_ip) is None - with patch.object(ipresolver, '_get_aws_ip_ranges', get_data): - assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) - assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) - assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789) - assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789) +def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app): + with patch('util.ipresolver._UpdateIPRange'): + ipresolver = IPResolver(app) + + with patch.dict('util.ipresolver.CACHE', test_ip_range_cache): + assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) + assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) + assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789) + assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)