Make IPResolver run the update in a separate thread
A separate thread caches the results of parsing the range file, and the IPResolver hits the cache instead of blocking while recomputing the ranges every time. The thread updates every 600s, and retries every 60s on failures.
This commit is contained in:
parent
975a3bfe3b
commit
b6336393de
3 changed files with 115 additions and 42 deletions
|
@ -8,7 +8,7 @@ import boto
|
||||||
from app import config_provider
|
from app import config_provider
|
||||||
from storage import CloudFrontedS3Storage, StorageContext
|
from storage import CloudFrontedS3Storage, StorageContext
|
||||||
from util.ipresolver import IPResolver
|
from util.ipresolver import IPResolver
|
||||||
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data
|
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
_TEST_CONTENT = os.urandom(1024)
|
_TEST_CONTENT = os.urandom(1024)
|
||||||
|
@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path'
|
||||||
def ipranges_populated(request):
|
def ipranges_populated(request):
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
@pytest.fixture()
def test_empty_ip_range_cache(empty_range_data):
  """ Fixture: a fake IP-range cache built from the empty range document. """
  parsed_amazon, parsed_regions = IPResolver._parse_amazon_ranges(empty_range_data)
  return {
    'sync_token': empty_range_data['syncToken'],
    'all_amazon': parsed_amazon,
    'regions': parsed_regions,
  }
|
||||||
|
|
||||||
|
@pytest.fixture()
def empty_range_data():
  """ Fixture: an AWS ip-ranges document containing no prefixes. """
  return {
    'syncToken': 123456789,
    'prefixes': [],
  }
|
||||||
|
|
||||||
@mock_s3
|
@mock_s3
|
||||||
def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app):
|
def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app):
|
||||||
ipresolver = IPResolver(app)
|
ipresolver = IPResolver(app)
|
||||||
if ipranges_populated:
|
if ipranges_populated:
|
||||||
empty_range_data = {
|
|
||||||
'syncToken': 123456789,
|
|
||||||
'prefixes': [],
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data):
|
with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache):
|
||||||
context = StorageContext('nyc', None, None, config_provider, ipresolver)
|
context = StorageContext('nyc', None, None, config_provider, ipresolver)
|
||||||
|
|
||||||
# Create a test bucket and put some test content.
|
# Create a test bucket and put some test content.
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
from collections import namedtuple, defaultdict
|
from collections import namedtuple, defaultdict
|
||||||
|
|
||||||
|
from threading import Thread, Lock
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from six import add_metaclass
|
from six import add_metaclass
|
||||||
from cachetools import ttl_cache, lru_cache
|
from cachetools import ttl_cache, lru_cache
|
||||||
|
@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
|
_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
|
||||||
|
_UPDATE_INTERVAL = 600
|
||||||
|
_FAILED_UPDATE_RETRY_SECS = 60
|
||||||
|
|
||||||
|
CACHE = {}
|
||||||
|
CACHE_LOCK = Lock()
|
||||||
|
|
||||||
def update_resolver_datafiles():
|
def update_resolver_datafiles():
|
||||||
""" Performs an update of the data file(s) used by the IP Resolver. """
|
""" Performs an update of the data file(s) used by the IP Resolver. """
|
||||||
|
@ -33,6 +40,20 @@ def update_resolver_datafiles():
|
||||||
f.write(response.text)
|
f.write(response.text)
|
||||||
logger.debug('Successfully wrote %s', filename)
|
logger.debug('Successfully wrote %s', filename)
|
||||||
|
|
||||||
|
def _get_aws_ip_ranges():
  """ Loads the AWS IP range document from the local data file.

  Returns:
    The parsed JSON document as a dict, or None when the file is missing,
    unreadable, or does not contain valid JSON. Callers treat None as
    "no data available yet".
  """
  try:
    with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
      return json.loads(f.read())
  except (IOError, ValueError, TypeError):
    # IOError: file missing/unreadable; ValueError/TypeError: malformed JSON.
    # All three were previously separate, identical handlers — merged.
    logger.exception('Could not load AWS IP Ranges')
    return None
|
||||||
|
|
||||||
|
|
||||||
@add_metaclass(ABCMeta)
|
@add_metaclass(ABCMeta)
|
||||||
class IPResolverInterface(object):
|
class IPResolverInterface(object):
|
||||||
|
@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface):
|
||||||
def __init__(self, app):
  """ Initializes the resolver: the app reference, the GeoIP country
      database, and the background worker that refreshes the range cache.
  """
  self.app = app
  self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
  # Background thread that periodically refreshes the module-level CACHE.
  self._worker = _UpdateIPRange(_UPDATE_INTERVAL)
  # Under TESTING the worker is never started; tests patch CACHE directly.
  if not app.config.get('TESTING', False):
    self._worker.start()
||||||
|
|
||||||
@ttl_cache(maxsize=100, ttl=600)
|
@ttl_cache(maxsize=100, ttl=600)
|
||||||
def is_ip_possible_threat(self, ip_address):
|
def is_ip_possible_threat(self, ip_address):
|
||||||
|
@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface):
|
||||||
|
|
||||||
return location_function(ip_address)
|
return location_function(ip_address)
|
||||||
|
|
||||||
def _get_aws_ip_ranges(self):
|
|
||||||
try:
|
|
||||||
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
|
|
||||||
return json.loads(f.read())
|
|
||||||
except IOError:
|
|
||||||
logger.exception('Could not load AWS IP Ranges')
|
|
||||||
return None
|
|
||||||
except ValueError:
|
|
||||||
logger.exception('Could not load AWS IP Ranges')
|
|
||||||
return None
|
|
||||||
except TypeError:
|
|
||||||
logger.exception('Could not load AWS IP Ranges')
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_location_function(self):
  """ Builds the IP -> ResolvedLocation function from the cached AWS ranges.

  Returns:
    The location function, or None when the background worker has not
    populated the cache yet (or the cache is in an unexpected state).
  """
  try:
    # NOTE(review): read without CACHE_LOCK — assumes the writer's per-key
    # updates are safe to observe unlocked; confirm torn reads are acceptable.
    cache = CACHE
    sync_token = cache.get('sync_token', None)
    if sync_token is None:
      logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json'])
      return None

    all_amazon = cache['all_amazon']
    regions = cache['regions']
  except Exception:
    # Single handler: the previous separate `except KeyError` clause was an
    # exact duplicate of this one, so it was dead code.
    logger.exception('Got exception trying to hit aws ip range cache')
    return None

  return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface):
|
||||||
regions[region].add(cidr)
|
regions[region].add(cidr)
|
||||||
|
|
||||||
return all_amazon, regions
|
return all_amazon, regions
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateIPRange(Thread):
|
||||||
|
"""Helper class that uses a thread to loads the IP ranges from Amazon"""
|
||||||
|
def __init__(self, interval):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.interval = interval
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json')
|
||||||
|
aws_ip_range_json = _get_aws_ip_ranges()
|
||||||
|
except:
|
||||||
|
logger.exception('Failed trying to update aws ip range')
|
||||||
|
time.sleep(_FAILED_UPDATE_RETRY_SECS)
|
||||||
|
break
|
||||||
|
|
||||||
|
sync_token = aws_ip_range_json['syncToken']
|
||||||
|
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
|
||||||
|
|
||||||
|
with CACHE_LOCK:
|
||||||
|
CACHE['sync_token'] = sync_token
|
||||||
|
CACHE['all_amazon'] = all_amazon
|
||||||
|
CACHE['regions'] = regions
|
||||||
|
|
||||||
|
time.sleep(self.interval)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
|
|
||||||
from mock import patch
|
from mock import patch
|
||||||
|
|
||||||
from util.ipresolver import IPResolver, ResolvedLocation
|
from util.ipresolver import IPResolver, ResolvedLocation, CACHE
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
@ -23,23 +23,35 @@ def aws_ip_range_data():
|
||||||
}
|
}
|
||||||
return fake_range_doc
|
return fake_range_doc
|
||||||
|
|
||||||
def test_unstarted(app, test_aws_ip):
|
@pytest.fixture()
def test_ip_range_cache(aws_ip_range_data):
  """ Fixture: a fake IP-range cache built from the populated range document. """
  parsed_amazon, parsed_regions = IPResolver._parse_amazon_ranges(aws_ip_range_data)
  return {
    'sync_token': aws_ip_range_data['syncToken'],
    'all_amazon': parsed_amazon,
    'regions': parsed_regions,
  }
|
||||||
|
|
||||||
def get_data():
|
@pytest.fixture()
def unstarted_cache():
  """ Fixture: the cache state before the update worker has ever run. """
  return {}
|
||||||
|
|
||||||
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
|
def test_unstarted(app, test_aws_ip, unstarted_cache):
  """ With an empty (never-populated) cache, resolution yields no location. """
  # Patch the worker class so constructing IPResolver spawns no thread.
  with patch('util.ipresolver._UpdateIPRange'):
    ipresolver = IPResolver(app)

    with patch.dict('util.ipresolver.CACHE', unstarted_cache):
      assert ipresolver.resolve_ip(test_aws_ip) is None
|
||||||
|
|
||||||
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
|
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
  """ With a populated cache, known IPs resolve to AWS / internet locations. """
  # Patch the worker class so constructing IPResolver spawns no thread.
  with patch('util.ipresolver._UpdateIPRange'):
    ipresolver = IPResolver(app)

    with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
      assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
      assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
      assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
      assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)
|
||||||
|
|
Reference in a new issue