Merge pull request #3223 from kleesc/UpdateIPRange

Create the _UpdateIPRange worker class: AWS IP ranges are now loaded into a module-level CACHE by a background thread instead of being re-read from aws-ip-ranges.json on every lookup (see the sketch below).
Kenny Lee Sin Cheong 2018-08-31 17:09:17 -04:00 committed by GitHub
commit b39051c142
3 changed files with 115 additions and 42 deletions
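In short, the PR stops IPResolver from re-reading and re-parsing aws-ip-ranges.json on demand: a new _UpdateIPRange background thread periodically loads the file and publishes the parsed ranges into a module-level CACHE dict under CACHE_LOCK, and _get_location_function now reads from that cache. Below is a minimal sketch of that pattern, condensed from the diff that follows; error handling and the all_amazon/regions parsing are omitted, so this is not the literal code of the commit.

# Condensed sketch of the producer side (not the exact code from the diff).
import json
import time
from threading import Thread, Lock

CACHE = {}           # shared cache read by IPResolver._get_location_function
CACHE_LOCK = Lock()  # serializes writes from the refresh thread

class _UpdateIPRange(Thread):
  """ Periodically reloads the AWS IP ranges into CACHE. """
  def __init__(self, interval):
    Thread.__init__(self)
    self.interval = interval

  def run(self):
    while True:
      # Read the on-disk ranges document and publish its sync token under the lock.
      with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
        ranges = json.loads(f.read())
      with CACHE_LOCK:
        CACHE['sync_token'] = ranges['syncToken']
      time.sleep(self.interval)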

@@ -8,7 +8,7 @@ import boto
from app import config_provider
from storage import CloudFrontedS3Storage, StorageContext
from util.ipresolver import IPResolver
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data
from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data, test_ip_range_cache
from test.fixtures import *
_TEST_CONTENT = os.urandom(1024)
@@ -21,16 +21,31 @@ _TEST_PATH = 'some/cool/path'
def ipranges_populated(request):
  return request.param
@mock_s3
def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app):
  ipresolver = IPResolver(app)
  if ipranges_populated:

@pytest.fixture()
def test_empty_ip_range_cache(empty_range_data):
  sync_token = empty_range_data['syncToken']
  all_amazon, regions = IPResolver._parse_amazon_ranges(empty_range_data)
  fake_cache = {
    'sync_token': sync_token,
    'all_amazon': all_amazon,
    'regions': regions,
  }
  return fake_cache

@pytest.fixture()
def empty_range_data():
  empty_range_data = {
    'syncToken': 123456789,
    'prefixes': [],
  }
  return empty_range_data
  with patch.object(ipresolver, '_get_aws_ip_ranges', lambda: aws_ip_range_data if ipranges_populated else empty_range_data):

@mock_s3
def test_direct_download(test_aws_ip, test_empty_ip_range_cache, test_ip_range_cache, aws_ip_range_data, ipranges_populated, app):
  ipresolver = IPResolver(app)
  if ipranges_populated:
  with patch.dict('util.ipresolver.CACHE', test_ip_range_cache if ipranges_populated else test_empty_ip_range_cache):
    context = StorageContext('nyc', None, None, config_provider, ipresolver)
    # Create a test bucket and put some test content.

@@ -1,8 +1,10 @@
import logging
import json
import time
from collections import namedtuple, defaultdict
from threading import Thread, Lock
from abc import ABCMeta, abstractmethod
from six import add_metaclass
from cachetools import ttl_cache, lru_cache
@@ -19,6 +21,11 @@ ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'servic
logger = logging.getLogger(__name__)
_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
_UPDATE_INTERVAL = 600
_FAILED_UPDATE_RETRY_SECS = 60
CACHE = {}
CACHE_LOCK = Lock()
def update_resolver_datafiles():
  """ Performs an update of the data file(s) used by the IP Resolver. """
@@ -33,6 +40,20 @@ def update_resolver_datafiles():
      f.write(response.text)
    logger.debug('Successfully wrote %s', filename)

def _get_aws_ip_ranges():
  try:
    with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
      return json.loads(f.read())
  except IOError:
    logger.exception('Could not load AWS IP Ranges')
    return None
  except ValueError:
    logger.exception('Could not load AWS IP Ranges')
    return None
  except TypeError:
    logger.exception('Could not load AWS IP Ranges')
    return None
@add_metaclass(ABCMeta)
class IPResolverInterface(object):
@@ -62,6 +83,9 @@ class IPResolver(IPResolverInterface):
  def __init__(self, app):
    self.app = app
    self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
    self._worker = _UpdateIPRange(_UPDATE_INTERVAL)
    if not app.config.get('TESTING', False):
      self._worker.start()

  @ttl_cache(maxsize=100, ttl=600)
  def is_ip_possible_threat(self, ip_address):
@@ -106,28 +130,23 @@ class IPResolver(IPResolverInterface):
    return location_function(ip_address)

  def _get_aws_ip_ranges(self):
    try:
      with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
        return json.loads(f.read())
    except IOError:
      logger.exception('Could not load AWS IP Ranges')
      return None
    except ValueError:
      logger.exception('Could not load AWS IP Ranges')
      return None
    except TypeError:
      logger.exception('Could not load AWS IP Ranges')
      return None

  @ttl_cache(maxsize=1, ttl=600)
  def _get_location_function(self):
    aws_ip_range_json = self._get_aws_ip_ranges()
    if aws_ip_range_json is None:
    try:
      cache = CACHE
      sync_token = cache.get('sync_token', None)
      if sync_token is None:
        logger.debug('The aws ip range has not been cached from %s', _DATA_FILES['aws-ip-ranges.json'])
        return None
      all_amazon = cache['all_amazon']
      regions = cache['regions']
    except KeyError:
      logger.exception('Got exception trying to hit aws ip range cache')
      return None
    except Exception:
      logger.exception('Got exception trying to hit aws ip range cache')
      return None
    sync_token = aws_ip_range_json['syncToken']
    all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)
    return IPResolver._build_location_function(sync_token, all_amazon, regions, self.geoip_db)

  @staticmethod
@@ -175,3 +194,30 @@ class IPResolver(IPResolverInterface):
      regions[region].add(cidr)
    return all_amazon, regions

class _UpdateIPRange(Thread):
  """ Helper class that uses a thread to load the IP ranges from Amazon. """
  def __init__(self, interval):
    Thread.__init__(self)
    self.interval = interval

  def run(self):
    while True:
      try:
        logger.debug('Updating aws ip range from "%s"', 'util/ipresolver/aws-ip-ranges.json')
        aws_ip_range_json = _get_aws_ip_ranges()
      except:
        logger.exception('Failed trying to update aws ip range')
        time.sleep(_FAILED_UPDATE_RETRY_SECS)
        break

      sync_token = aws_ip_range_json['syncToken']
      all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_json)

      with CACHE_LOCK:
        CACHE['sync_token'] = sync_token
        CACHE['all_amazon'] = all_amazon
        CACHE['regions'] = regions

      time.sleep(self.interval)
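For orientation, a hedged usage sketch of the util.ipresolver module diffed above: IPResolver(app) builds the _UpdateIPRange worker and starts it only when app.config['TESTING'] is falsy, and lookups go through resolve_ip, which returns a ResolvedLocation namedtuple or None while CACHE is still empty. The snippet is illustrative only; the Flask-style app object and the sample address are assumptions, not part of the diff.

# Hypothetical wiring; `app` is assumed to be the application object with a .config dict.
from util.ipresolver import IPResolver

ip = '10.0.0.2'                     # illustrative address (also used in the tests below)
resolver = IPResolver(app)          # starts the _UpdateIPRange worker unless TESTING is set
location = resolver.resolve_ip(ip)  # ResolvedLocation namedtuple, or None while CACHE is empty
if location is not None:
  provider = location.provider      # e.g. 'aws' or 'internet'
  sync_token = location.sync_token  # syncToken of the ranges document that produced the match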

@@ -2,7 +2,7 @@ import pytest
from mock import patch
from util.ipresolver import IPResolver, ResolvedLocation
from util.ipresolver import IPResolver, ResolvedLocation, CACHE
from test.fixtures import *
@pytest.fixture()
@@ -23,22 +23,34 @@ def aws_ip_range_data():
  }
  return fake_range_doc

def test_unstarted(app, test_aws_ip):

@pytest.fixture()
def test_ip_range_cache(aws_ip_range_data):
  sync_token = aws_ip_range_data['syncToken']
  all_amazon, regions = IPResolver._parse_amazon_ranges(aws_ip_range_data)
  fake_cache = {
    'sync_token': sync_token,
    'all_amazon': all_amazon,
    'regions': regions,
  }
  return fake_cache

@pytest.fixture()
def unstarted_cache():
  fake_unstarted_cache = {}
  return fake_unstarted_cache

def test_unstarted(app, test_aws_ip, unstarted_cache):
  with patch('util.ipresolver._UpdateIPRange'):
    ipresolver = IPResolver(app)

  def get_data():
    return None

  with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
  with patch.dict('util.ipresolver.CACHE', unstarted_cache):
    assert ipresolver.resolve_ip(test_aws_ip) is None

def test_resolved(aws_ip_range_data, test_aws_ip, app):
def test_resolved(aws_ip_range_data, test_ip_range_cache, test_aws_ip, app):
  with patch('util.ipresolver._UpdateIPRange'):
    ipresolver = IPResolver(app)

  def get_data():
    return aws_ip_range_data

  with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
  with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
    assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
    assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
    assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
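One pattern in the updated tests is worth spelling out: patch.dict('util.ipresolver.CACHE', ...) temporarily merges a fixture-built dict into the module-level cache and restores the previous contents on exit, while patch('util.ipresolver._UpdateIPRange') replaces the worker class so no real thread is created. A short hedged sketch of that pattern follows; the test name is hypothetical, the fixtures are the ones defined above.

# Hypothetical test illustrating the CACHE patching pattern used above.
from mock import patch
from util.ipresolver import IPResolver, CACHE

def test_cache_patching_sketch(app, test_aws_ip, test_ip_range_cache):
  with patch('util.ipresolver._UpdateIPRange'):
    ipresolver = IPResolver(app)
  with patch.dict('util.ipresolver.CACHE', test_ip_range_cache):
    # With the cache populated, the AWS test address resolves to a location.
    assert ipresolver.resolve_ip(test_aws_ip) is not None
  # patch.dict restores the prior CACHE contents on exit (empty here, as the worker never ran).
  assert 'sync_token' not in CACHE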