Make IPResolver run the update in a separate thread

A separate thread will cache the results of parsing the range
file, and the IPResolver will hit the cache instead of blocking while recomputing the
ranges every time. The thread updates every 600s, and retries every 60s on
failures.
This commit is contained in:
Kenny Lee Sin Cheong 2018-08-23 11:49:51 -04:00 committed by Kenny Lee Sin Cheong
parent 975a3bfe3b
commit b6336393de
3 changed files with 115 additions and 42 deletions

View file

@ -8,7 +8,7 @@ import boto
from app import config_provider from app import config_provider
from storage import CloudFrontedS3Storage, StorageContext from storage import CloudFrontedS3Storage, StorageContext
from util.ipresolver import IPResolver from util.ipresolver import IPResolver
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache
from test.fixtures import * from test.fixtures import *
_TEST_CONTENT = os.urandom(1024) _TEST_CONTENT = os.urandom(1024)
@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path'
def ipranges_populated(request): def ipranges_populated(request):
return request.param return request.param
@pytest.fixture()
def test_empty_ip_range_cache(empty_range_data):
sync_token = empty_range_data['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(empty_range_data)
fake_cache = {
'sync_token': sync_token,
'all_amazon': all_amazon,
'regions': regions,
}
return fake_cache
@pytest.fixture()
def empty_range_data():
empty_range_data = {
'syncToken': 123456789,
'prefixes': [],
}
return empty_range_data
@mock_s3 @mock_s3
def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app): def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app):
ipresolver = IPResolver(app) ipresolver = IPResolver(app)
if ipranges_populated: if ipranges_populated:
empty_range_data = {
'syncToken': 123456789,
'prefixes': [],
}
with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data): with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache):
context = StorageContext('nyc', None, None, config_provider, ipresolver) context = StorageContext('nyc', None, None, config_provider, ipresolver)
# Create a test bucket and put some test content. # Create a test bucket and put some test content.

View file

@ -1,8 +1,10 @@
import logging import logging
import json import json
import time
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from threading import Thread, Lock
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six import add_metaclass from six import add_metaclass
from cachetools import ttl_cache, lru_cache from cachetools import ttl_cache, lru_cache
@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'} _DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
_UPDATE_INTERVAL = 600
_FAILED_UPDATE_RETRY_SECS = 60
CACHE = {}
CACHE_LOCK = Lock()
def update_resolver_datafiles(): def update_resolver_datafiles():
""" Performs an update of the data file(s) used by the IP Resolver. """ """ Performs an update of the data file(s) used by the IP Resolver. """
@ -33,6 +40,20 @@ def update_resolver_datafiles():
f.write(response.text) f.write(response.text)
logger.debug('Successfully wrote %s', filename) logger.debug('Successfully wrote %s', filename)
def _get_aws_ip_ranges():
try:
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
return json.loads(f.read())
except IOError:
logger.exception('Could not load AWS IP Ranges')
return None
except ValueError:
logger.exception('Could not load AWS IP Ranges')
return None
except TypeError:
logger.exception('Could not load AWS IP Ranges')
return None
@add_metaclass(ABCMeta) @add_metaclass(ABCMeta)
class IPResolverInterface(object): class IPResolverInterface(object):
@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface):
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb') self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
self._worker = _UpdateIPRange(_UPDATE_INTERVAL)
if not app.config.get('TESTING', False):
self._worker.start()
@ttl_cache(maxsize=100, ttl=600) @ttl_cache(maxsize=100, ttl=600)
def is_ip_possible_threat(self, ip_address): def is_ip_possible_threat(self, ip_address):
@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface):
return location_function(ip_address) return location_function(ip_address)
def _get_aws_ip_ranges(self):
try:
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
return json.loads(f.read())
except IOError:
logger.exception('Could not load AWS IP Ranges')
return None
except ValueError:
logger.exception('Could not load AWS IP Ranges')
return None
except TypeError:
logger.exception('Could not load AWS IP Ranges')
return None
@ttl_cache(maxsize=1, ttl=600)
def _get_location_function(self): def _get_location_function(self):
aws_ip_range_json = self._get_aws_ip_ranges() try:
if aws_ip_range_json is None: cache = CACHE
sync_token = cache.get('sync_token', None)
if sync_token is None:
logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json'])
return None
all_amazon = cache['all_amazon']
regions = cache['regions']
except KeyError:
logger.exception('Got exception trying to hit aws ip range cache')
return None
except Exception:
logger.exception('Got exception trying to hit aws ip range cache')
return None return None
sync_token = aws_ip_range_json['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db) return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db)
@staticmethod @staticmethod
@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface):
regions[region].add(cidr) regions[region].add(cidr)
return all_amazon, regions return all_amazon, regions
class _UpdateIPRange(Thread):
"""Helper class that uses a thread to loads the IP ranges from Amazon"""
def __init__(self, interval):
Thread.__init__(self)
self.interval = interval
def run(self):
while True:
try:
logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json')
aws_ip_range_json = _get_aws_ip_ranges()
except:
logger.exception('Failed trying to update aws ip range')
time.sleep(_FAILED_UPDATE_RETRY_SECS)
break
sync_token = aws_ip_range_json['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
with CACHE_LOCK:
CACHE['sync_token'] = sync_token
CACHE['all_amazon'] = all_amazon
CACHE['regions'] = regions
time.sleep(self.interval)

View file

@ -2,7 +2,7 @@ import pytest
from mock import patch from mock import patch
from util.ipresolver import IPResolver, ResolvedLocation from util.ipresolver import IPResolver, ResolvedLocation, CACHE
from test.fixtures import * from test.fixtures import *
@pytest.fixture() @pytest.fixture()
@ -23,23 +23,35 @@ def aws_ip_range_data():
} }
return fake_range_doc return fake_range_doc
def test_unstarted(app, test_aws_ip): @pytest.fixture()
ipresolver = IPResolver(app) def test_ip_range_cache(aws_ip_range_data):
sync_token = aws_ip_range_data['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_data)
fake_cache = {
'sync_token': sync_token,
'all_amazon': all_amazon,
'regions': regions,
}
return fake_cache
def get_data(): @pytest.fixture()
return None def unstarted_cache():
fake_unstarted_cache = {}
return fake_unstarted_cache
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data): def test_unstarted(app, test_aws_ip, unstarted_cache):
assert ipresolver.resolve_ip(test_aws_ip) is None with patch('util.ipresolver._UpdateIPRange'):
def test_resolved(aws_ip_range_data, test_aws_ip, app):
ipresolver = IPResolver(app) ipresolver = IPResolver(app)
def get_data(): with patch.dict('util.ipresolver.CACHE', unstarted_cache):
return aws_ip_range_data assert ipresolver.resolve_ip(test_aws_ip) is None
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data): def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) with patch('util.ipresolver._UpdateIPRange'):
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789) ipresolver = IPResolver(app)
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789) with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)