Make IPResolver run the update in a separate thread

A separate thread will cache the results of parsing the range
file, and the IPResolver will hit the cache instead of blocking while recomputing the
ranges everytime. The thread updates every 600s, and retry every 60s on
failures.
This commit is contained in:
Kenny Lee Sin Cheong 2018-08-23 11:49:51 -04:00 committed by Kenny Lee Sin Cheong
parent 975a3bfe3b
commit b6336393de
3 changed files with 115 additions and 42 deletions

View file

@ -8,7 +8,7 @@ import boto
from app import config_provider
from storage import CloudFrontedS3Storage, StorageContext
from util.ipresolver import IPResolver
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache
from test.fixtures import *
_TEST_CONTENT = os.urandom(1024)
@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path'
def ipranges_populated(request):
return request.param
@pytest.fixture()
def test_empty_ip_range_cache(empty_range_data):
sync_token = empty_range_data['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(empty_range_data)
fake_cache = {
'sync_token': sync_token,
'all_amazon': all_amazon,
'regions': regions,
}
return fake_cache
@pytest.fixture()
def empty_range_data():
empty_range_data = {
'syncToken': 123456789,
'prefixes': [],
}
return empty_range_data
@mock_s3
def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app):
def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app):
ipresolver = IPResolver(app)
if ipranges_populated:
empty_range_data = {
'syncToken': 123456789,
'prefixes': [],
}
with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data):
with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache):
context = StorageContext('nyc', None, None, config_provider, ipresolver)
# Create a test bucket and put some test content.

View file

@ -1,8 +1,10 @@
import logging
import json
import time
from collections import namedtuple, defaultdict
from threading import Thread, Lock
from abc import ABCMeta, abstractmethod
from six import add_metaclass
from cachetools import ttl_cache, lru_cache
@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic
logger = logging.getLogger(__name__)
_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
_UPDATE_INTERVAL = 600
_FAILED_UPDATE_RETRY_SECS = 60
CACHE = {}
CACHE_LOCK = Lock()
def update_resolver_datafiles():
""" Performs an update of the data file(s) used by the IP Resolver. """
@ -33,6 +40,20 @@ def update_resolver_datafiles():
f.write(response.text)
logger.debug('Successfully wrote %s', filename)
def _get_aws_ip_ranges():
try:
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
return json.loads(f.read())
except IOError:
logger.exception('Could not load AWS IP Ranges')
return None
except ValueError:
logger.exception('Could not load AWS IP Ranges')
return None
except TypeError:
logger.exception('Could not load AWS IP Ranges')
return None
@add_metaclass(ABCMeta)
class IPResolverInterface(object):
@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface):
def __init__(self, app):
self.app = app
self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
self._worker = _UpdateIPRange(_UPDATE_INTERVAL)
if not app.config.get('TESTING', False):
self._worker.start()
@ttl_cache(maxsize=100, ttl=600)
def is_ip_possible_threat(self, ip_address):
@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface):
return location_function(ip_address)
def _get_aws_ip_ranges(self):
try:
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
return json.loads(f.read())
except IOError:
logger.exception('Could not load AWS IP Ranges')
return None
except ValueError:
logger.exception('Could not load AWS IP Ranges')
return None
except TypeError:
logger.exception('Could not load AWS IP Ranges')
return None
@ttl_cache(maxsize=1, ttl=600)
def _get_location_function(self):
aws_ip_range_json = self._get_aws_ip_ranges()
if aws_ip_range_json is None:
try:
cache = CACHE
sync_token = cache.get('sync_token', None)
if sync_token is None:
logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json'])
return None
all_amazon = cache['all_amazon']
regions = cache['regions']
except KeyError:
logger.exception('Got exception trying to hit aws ip range cache')
return None
except Exception:
logger.exception('Got exception trying to hit aws ip range cache')
return None
sync_token = aws_ip_range_json['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db)
@staticmethod
@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface):
regions[region].add(cidr)
return all_amazon, regions
class _UpdateIPRange(Thread):
"""Helper class that uses a thread to loads the IP ranges from Amazon"""
def __init__(self, interval):
Thread.__init__(self)
self.interval = interval
def run(self):
while True:
try:
logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json')
aws_ip_range_json = _get_aws_ip_ranges()
except:
logger.exception('Failed trying to update aws ip range')
time.sleep(_FAILED_UPDATE_RETRY_SECS)
break
sync_token = aws_ip_range_json['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
with CACHE_LOCK:
CACHE['sync_token'] = sync_token
CACHE['all_amazon'] = all_amazon
CACHE['regions'] = regions
time.sleep(self.interval)

View file

@ -2,7 +2,7 @@ import pytest
from mock import patch
from util.ipresolver import IPResolver, ResolvedLocation
from util.ipresolver import IPResolver, ResolvedLocation, CACHE
from test.fixtures import *
@pytest.fixture()
@ -23,23 +23,35 @@ def aws_ip_range_data():
}
return fake_range_doc
def test_unstarted(app, test_aws_ip):
ipresolver = IPResolver(app)
@pytest.fixture()
def test_ip_range_cache(aws_ip_range_data):
sync_token = aws_ip_range_data['syncToken']
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_data)
fake_cache = {
'sync_token': sync_token,
'all_amazon': all_amazon,
'regions': regions,
}
return fake_cache
def get_data():
return None
@pytest.fixture()
def unstarted_cache():
fake_unstarted_cache = {}
return fake_unstarted_cache
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
assert ipresolver.resolve_ip(test_aws_ip) is None
def test_resolved(aws_ip_range_data, test_aws_ip, app):
def test_unstarted(app, test_aws_ip, unstarted_cache):
with patch('util.ipresolver._UpdateIPRange'):
ipresolver = IPResolver(app)
def get_data():
return aws_ip_range_data
with patch.dict('util.ipresolver.CACHE', unstarted_cache):
assert ipresolver.resolve_ip(test_aws_ip) is None
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
with patch('util.ipresolver._UpdateIPRange'):
ipresolver = IPResolver(app)
with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)