Merge pull request #3223 from kleesc/UpdateIPRange
Create UpdateIPRange class
This commit is contained in:
commit
b39051c142
3 changed files with 115 additions and 42 deletions
|
@ -8,7 +8,7 @@ import boto
|
|||
from app import config_provider
|
||||
from storage import CloudFrontedS3Storage, StorageContext
|
||||
from util.ipresolver import IPResolver
|
||||
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data
|
||||
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache
|
||||
from test.fixtures import *
|
||||
|
||||
_TEST_CONTENT = os.urandom(1024)
|
||||
|
@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path'
|
|||
def ipranges_populated(request):
|
||||
return request.param
|
||||
|
||||
@mock_s3
|
||||
def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app):
|
||||
ipresolver = IPResolver(app)
|
||||
if ipranges_populated:
|
||||
@pytest.fixture()
|
||||
def test_empty_ip_range_cache(empty_range_data):
|
||||
sync_token = empty_range_data['syncToken']
|
||||
all_amazon, regions = IPResolver._parse_amazon_ranges(empty_range_data)
|
||||
fake_cache = {
|
||||
'sync_token': sync_token,
|
||||
'all_amazon': all_amazon,
|
||||
'regions': regions,
|
||||
}
|
||||
return fake_cache
|
||||
|
||||
@pytest.fixture()
|
||||
def empty_range_data():
|
||||
empty_range_data = {
|
||||
'syncToken': 123456789,
|
||||
'prefixes': [],
|
||||
}
|
||||
return empty_range_data
|
||||
|
||||
with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data):
|
||||
@mock_s3
|
||||
def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app):
|
||||
ipresolver = IPResolver(app)
|
||||
if ipranges_populated:
|
||||
|
||||
with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache):
|
||||
context = StorageContext('nyc', None, None, config_provider, ipresolver)
|
||||
|
||||
# Create a test bucket and put some test content.
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
import json
|
||||
import time
|
||||
|
||||
from collections import namedtuple, defaultdict
|
||||
|
||||
from threading import Thread, Lock
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import add_metaclass
|
||||
from cachetools import ttl_cache, lru_cache
|
||||
|
@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
|
||||
_UPDATE_INTERVAL = 600
|
||||
_FAILED_UPDATE_RETRY_SECS = 60
|
||||
|
||||
CACHE = {}
|
||||
CACHE_LOCK = Lock()
|
||||
|
||||
def update_resolver_datafiles():
|
||||
""" Performs an update of the data file(s) used by the IP Resolver. """
|
||||
|
@ -33,6 +40,20 @@ def update_resolver_datafiles():
|
|||
f.write(response.text)
|
||||
logger.debug('Successfully wrote %s', filename)
|
||||
|
||||
def _get_aws_ip_ranges():
|
||||
try:
|
||||
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
|
||||
return json.loads(f.read())
|
||||
except IOError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
except ValueError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
except TypeError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class IPResolverInterface(object):
|
||||
|
@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface):
|
|||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
|
||||
self._worker = _UpdateIPRange(_UPDATE_INTERVAL)
|
||||
if not app.config.get('TESTING', False):
|
||||
self._worker.start()
|
||||
|
||||
@ttl_cache(maxsize=100, ttl=600)
|
||||
def is_ip_possible_threat(self, ip_address):
|
||||
|
@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface):
|
|||
|
||||
return location_function(ip_address)
|
||||
|
||||
def _get_aws_ip_ranges(self):
|
||||
try:
|
||||
with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
|
||||
return json.loads(f.read())
|
||||
except IOError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
except ValueError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
except TypeError:
|
||||
logger.exception('Could not load AWS IP Ranges')
|
||||
return None
|
||||
|
||||
@ttl_cache(maxsize=1, ttl=600)
|
||||
def _get_location_function(self):
|
||||
aws_ip_range_json = self._get_aws_ip_ranges()
|
||||
if aws_ip_range_json is None:
|
||||
try:
|
||||
cache = CACHE
|
||||
sync_token = cache.get('sync_token', None)
|
||||
if sync_token is None:
|
||||
logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json'])
|
||||
return None
|
||||
|
||||
all_amazon = cache['all_amazon']
|
||||
regions = cache['regions']
|
||||
except KeyError:
|
||||
logger.exception('Got exception trying to hit aws ip range cache')
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception('Got exception trying to hit aws ip range cache')
|
||||
return None
|
||||
|
||||
sync_token = aws_ip_range_json['syncToken']
|
||||
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
|
||||
return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db)
|
||||
|
||||
@staticmethod
|
||||
|
@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface):
|
|||
regions[region].add(cidr)
|
||||
|
||||
return all_amazon, regions
|
||||
|
||||
|
||||
class _UpdateIPRange(Thread):
|
||||
"""Helper class that uses a thread to loads the IP ranges from Amazon"""
|
||||
def __init__(self, interval):
|
||||
Thread.__init__(self)
|
||||
self.interval = interval
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
try:
|
||||
logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json')
|
||||
aws_ip_range_json = _get_aws_ip_ranges()
|
||||
except:
|
||||
logger.exception('Failed trying to update aws ip range')
|
||||
time.sleep(_FAILED_UPDATE_RETRY_SECS)
|
||||
break
|
||||
|
||||
sync_token = aws_ip_range_json['syncToken']
|
||||
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
|
||||
|
||||
with CACHE_LOCK:
|
||||
CACHE['sync_token'] = sync_token
|
||||
CACHE['all_amazon'] = all_amazon
|
||||
CACHE['regions'] = regions
|
||||
|
||||
time.sleep(self.interval)
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from mock import patch
|
||||
|
||||
from util.ipresolver import IPResolver, ResolvedLocation
|
||||
from util.ipresolver import IPResolver, ResolvedLocation, CACHE
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.fixture()
|
||||
|
@ -23,22 +23,34 @@ def aws_ip_range_data():
|
|||
}
|
||||
return fake_range_doc
|
||||
|
||||
def test_unstarted(app, test_aws_ip):
|
||||
@pytest.fixture()
|
||||
def test_ip_range_cache(aws_ip_range_data):
|
||||
sync_token = aws_ip_range_data['syncToken']
|
||||
all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_data)
|
||||
fake_cache = {
|
||||
'sync_token': sync_token,
|
||||
'all_amazon': all_amazon,
|
||||
'regions': regions,
|
||||
}
|
||||
return fake_cache
|
||||
|
||||
@pytest.fixture()
|
||||
def unstarted_cache():
|
||||
fake_unstarted_cache = {}
|
||||
return fake_unstarted_cache
|
||||
|
||||
def test_unstarted(app, test_aws_ip, unstarted_cache):
|
||||
with patch('util.ipresolver._UpdateIPRange'):
|
||||
ipresolver = IPResolver(app)
|
||||
|
||||
def get_data():
|
||||
return None
|
||||
|
||||
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
|
||||
with patch.dict('util.ipresolver.CACHE', unstarted_cache):
|
||||
assert ipresolver.resolve_ip(test_aws_ip) is None
|
||||
|
||||
def test_resolved(aws_ip_range_data, test_aws_ip, app):
|
||||
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
|
||||
with patch('util.ipresolver._UpdateIPRange'):
|
||||
ipresolver = IPResolver(app)
|
||||
|
||||
def get_data():
|
||||
return aws_ip_range_data
|
||||
|
||||
with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
|
||||
with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
|
||||
assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
|
||||
assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
|
||||
assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
|
||||
|
|
Reference in a new issue