Merge pull request #2869 from coreos-inc/joseph.schorr/QS-2/cloudfront

CloudFront redirect support
josephschorr 2017-09-29 12:08:50 -04:00 committed by GitHub
commit 82e09d6f16
27 changed files with 388 additions and 29 deletions


@@ -19,7 +19,7 @@ RUN virtualenv --distribute venv \
   && venv/bin/pip freeze
 # Install front-end dependencies
-# JS depedencies
+# JS dependencies
 COPY yarn.lock package.json tsconfig.json webpack.config.js tslint.json ./
 RUN yarn install --ignore-engines
@@ -31,6 +31,9 @@ RUN yarn build \
 COPY . .
+# Update local copy of AWS IP Ranges.
+RUN curl https://ip-ranges.amazonaws.com/ip-ranges.json -o util/ipresolver/aws-ip-ranges.json
+
 # Set up the init system
 RUN mkdir -p /etc/my_init.d /etc/systlog-ng /usr/local/bin /etc/monit static/fonts static/ldn /usr/local/nginx/logs/ \
   && cp $QUAYCONF/init/*.sh /etc/my_init.d/ \
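For reference, the document fetched above has the following abbreviated shape; the syncToken and per-prefix ip_prefix/region/service fields shown here are the only ones the IP resolver below consumes, and the values are illustrative rather than real AWS data:

    # Abbreviated, illustrative shape of ip-ranges.json (real documents contain
    # thousands of prefix entries).
    AWS_IP_RANGES_SAMPLE = {
      'syncToken': '1506552518',
      'prefixes': [
        {'ip_prefix': '52.95.110.0/24', 'region': 'us-east-1', 'service': 'AMAZON'},
        {'ip_prefix': '54.239.0.0/17', 'region': 'GLOBAL', 'service': 'AMAZON'},
      ],
    }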

app.py

@@ -35,6 +35,7 @@ from oauth.loginmanager import OAuthLoginManager
 from storage import Storage
 from util.log import filter_logs
 from util import get_app_url
+from util.ipresolver import IPResolver
 from util.saas.analytics import Analytics
 from util.saas.useranalytics import UserAnalytics
 from util.saas.exceptionlog import Sentry
@@ -195,7 +196,8 @@ prometheus = PrometheusPlugin(app)
 metric_queue = MetricQueue(prometheus)
 chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue)
 instance_keys = InstanceKeys(app)
-storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys)
+ip_resolver = IPResolver(app)
+storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver)
 userfiles = Userfiles(app, storage)
 log_archive = LogArchive(app, storage)
 analytics = Analytics(app)
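To make the wiring concrete, a minimal sketch of what a DISTRIBUTED_STORAGE_CONFIG entry for the new engine might look like. The cloudfront_* key names, storage_path, and s3_bucket come from the CloudFrontedS3Storage constructor later in this diff; the access/secret key entries are assumptions based on the positional arguments used in the tests, and every value is a placeholder:

    # Hypothetical config sketch; values are placeholders, not a working deployment.
    DISTRIBUTED_STORAGE_CONFIG = {
      'default': ['CloudFrontedS3Storage', {
        'cloudfront_distribution_domain': 'd111111abcdef8.cloudfront.net',
        'cloudfront_key_id': 'APKAEXAMPLEKEYID',
        'cloudfront_privatekey_filename': 'cloudfront-signing-key.pem',
        'storage_path': '/registry',
        's3_bucket': 'my-quay-bucket',
        's3_access_key': 'AKIAEXAMPLE',         # assumed parameter name
        's3_secret_key': 'example-secret-key',  # assumed parameter name
      }],
    }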


@@ -0,0 +1,7 @@
+#!/bin/sh
+
+# Ensure dependencies start before the logger
+sv check syslog-ng > /dev/null || exit 1
+
+# Start the logger
+exec logger -i -t ipresolverupdateworker


@@ -0,0 +1,9 @@
+#! /bin/bash
+
+echo 'Starting ip resolver update worker'
+
+QUAYPATH=${QUAYPATH:-"."}
+cd ${QUAYDIR:-"/"}
+
+PYTHONPATH=$QUAYPATH venv/bin/python -m workers.ipresolverupdateworker 2>&1
+echo 'IP resolver update worker exited'


@@ -107,7 +107,8 @@ class DelegateUserfiles(object):
   def get_file_url(self, file_id, expires_in=300, requires_cors=False):
     path = self.get_file_id_path(file_id)
-    url = self._storage.get_direct_download_url(self._locations, path, expires_in, requires_cors)
+    url = self._storage.get_direct_download_url(self._locations, path, request.remote_addr, expires_in,
+                                                requires_cors)

     if url is None:
       if self._handler_name is None:


@@ -6,6 +6,7 @@ from cnr.models.channel_base import ChannelBase
 from cnr.models.db_base import CnrDB
 from cnr.models.package_base import PackageBase, manifest_media_type
+from flask import request

 from app import storage
 from endpoints.appr.models_oci import model
@@ -36,7 +37,7 @@ class Blob(BlobBase):
     locations = model.get_blob_locations(digest)
     if not locations:
       raise_package_not_found(package_name, digest)
-    return storage.get_direct_download_url(locations, blobpath)
+    return storage.get_direct_download_url(locations, blobpath, request.remote_addr)

 class Channel(ChannelBase):


@@ -132,7 +132,7 @@ def get_image_layer(namespace, repository, image_id, headers):
     abort(404, 'Image %(image_id)s not found', issue='unknown-image', image_id=image_id)

   try:
     logger.debug('Looking up the direct download URL for path: %s', path)
-    direct_download_url = store.get_direct_download_url(locations, path)
+    direct_download_url = store.get_direct_download_url(locations, path, request.remote_addr)
     if direct_download_url:
       logger.debug('Returning direct download URL')
       resp = redirect(direct_download_url)


@@ -83,7 +83,7 @@ def download_blob(namespace_name, repo_name, digest):
   # Short-circuit by redirecting if the storage supports it.
   logger.debug('Looking up the direct download URL for path: %s', path)
-  direct_download_url = storage.get_direct_download_url(blob.locations, path)
+  direct_download_url = storage.get_direct_download_url(blob.locations, path, request.remote_addr)
   if direct_download_url:
     logger.debug('Returning direct download URL')
     resp = redirect(direct_download_url)


@@ -23,10 +23,12 @@ bencode
 bintrees
 bitmath
 boto
+boto3
 cachetools==1.1.6
 cryptography
 flask
 flask-restful
+geoip2
 gevent
 gipc
 gunicorn<19.0
@@ -40,6 +42,7 @@ mixpanel
 mock
 moto==0.4.25 # remove when 0.4.28+ is out
 namedlist
+netaddr
 pathvalidate
 peewee==2.8.1
 psutil


@@ -23,6 +23,7 @@ bintrees==2.0.6
 bitmath==1.3.1.2
 blinker==1.4
 boto==2.46.1
+boto3==1.4.7
 cachetools==1.1.6
 certifi==2017.4.17
 cffi==1.10.0
@@ -43,6 +44,7 @@ functools32==3.2.3.post2
 furl==1.0.0
 future==0.16.0
 futures==3.0.5
+geoip2==2.5.0
 gevent==1.2.1
 gipc==0.6.0
 greenlet==0.4.12


@@ -1,5 +1,5 @@
 from storage.local import LocalStorage
-from storage.cloud import S3Storage, GoogleCloudStorage, RadosGWStorage
+from storage.cloud import S3Storage, GoogleCloudStorage, RadosGWStorage, CloudFrontedS3Storage
 from storage.fakestorage import FakeStorage
 from storage.distributedstorage import DistributedStorage
 from storage.swift import SwiftStorage
@@ -11,39 +11,44 @@ STORAGE_DRIVER_CLASSES = {
   'GoogleCloudStorage': GoogleCloudStorage,
   'RadosGWStorage': RadosGWStorage,
   'SwiftStorage': SwiftStorage,
+  'CloudFrontedS3Storage': CloudFrontedS3Storage,
 }

-def get_storage_driver(location, metric_queue, chunk_cleanup_queue, storage_params):
+def get_storage_driver(location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver, storage_params):
   """ Returns a storage driver class for the given storage configuration
       (a pair of string name and a dict of parameters). """
   driver = storage_params[0]
   parameters = storage_params[1]
   driver_class = STORAGE_DRIVER_CLASSES.get(driver, FakeStorage)
-  context = StorageContext(location, metric_queue, chunk_cleanup_queue)
+  context = StorageContext(location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver)
   return driver_class(context, **parameters)

 class StorageContext(object):
-  def __init__(self, location, metric_queue, chunk_cleanup_queue):
+  def __init__(self, location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver):
     self.location = location
     self.metric_queue = metric_queue
     self.chunk_cleanup_queue = chunk_cleanup_queue
+    self.config_provider = config_provider
+    self.ip_resolver = ip_resolver

 class Storage(object):
-  def __init__(self, app=None, metric_queue=None, chunk_cleanup_queue=None, instance_keys=None):
+  def __init__(self, app=None, metric_queue=None, chunk_cleanup_queue=None, instance_keys=None,
+               config_provider=None, ip_resolver=None):
     self.app = app
     if app is not None:
-      self.state = self.init_app(app, metric_queue, chunk_cleanup_queue, instance_keys)
+      self.state = self.init_app(app, metric_queue, chunk_cleanup_queue, instance_keys,
+                                 config_provider, ip_resolver)
     else:
       self.state = None

-  def init_app(self, app, metric_queue, chunk_cleanup_queue, instance_keys):
+  def init_app(self, app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver):
     storages = {}
     for location, storage_params in app.config.get('DISTRIBUTED_STORAGE_CONFIG').items():
-      storages[location] = get_storage_driver(location, metric_queue, chunk_cleanup_queue,
-                                              storage_params)
+      storages[location] = get_storage_driver(location, metric_queue, chunk_cleanup_queue,
+                                              config_provider, ip_resolver, storage_params)

     preference = app.config.get('DISTRIBUTED_STORAGE_PREFERENCE', None)
     if not preference:
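As a quick illustration of the new factory signature, a sketch of resolving one location by hand (LocalStorage and its storage_path parameter are existing Quay driver arguments; passing None for the queues, config provider, and resolver mirrors what the config validator later in this diff does):

    # Sketch: resolving a single storage location through the updated factory.
    params = ['LocalStorage', {'storage_path': '/datastorage/registry'}]
    driver = get_storage_driver('local', None, None, None, None, params)
    # driver is a LocalStorage instance whose StorageContext now also carries
    # config_provider and ip_resolver (both None here).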


@@ -49,7 +49,7 @@ class BaseStorage(StoragePaths):
     if not self.exists('_verify'):
       raise Exception('Could not find verification file')

-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     return None

   def get_direct_upload_url(self, path, mime_type, requires_cors=True):


@@ -3,8 +3,17 @@ import os
 import logging
 import copy

+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import padding
+
+from cachetools import lru_cache
 from itertools import chain
+from datetime import datetime, timedelta
+
+from botocore.signers import CloudFrontSigner
 from boto.exception import S3ResponseError
 import boto.s3.connection
 import boto.s3.multipart
@@ -119,7 +128,7 @@ class _CloudStorage(BaseStorageV2):
   def get_supports_resumable_downloads(self):
     return True

-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     self._initialize_cloud_conn()
     path = self._init_path(path)
     k = self._key_class(self._cloud_bucket, path)
@@ -568,11 +577,11 @@ class RadosGWStorage(_CloudStorage):
                                          storage_path, bucket_name, access_key, secret_key)

   # TODO remove when radosgw supports cors: http://tracker.ceph.com/issues/8718#change-38624
-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     if requires_cors:
       return None

-    return super(RadosGWStorage, self).get_direct_download_url(path, expires_in, requires_cors,
-                                                               head)
+    return super(RadosGWStorage, self).get_direct_download_url(path, request_ip, expires_in, requires_cors,
+                                                               head)

   # TODO remove when radosgw supports cors: http://tracker.ceph.com/issues/8718#change-38624
@@ -590,3 +599,58 @@ class RadosGWStorage(_CloudStorage):
     # See https://github.com/ceph/ceph/pull/5139
     chunk_list = self._chunk_list_from_metadata(storage_metadata)
     self._client_side_chunk_join(final_path, chunk_list)
+
+class CloudFrontedS3Storage(S3Storage):
+  """ An S3Storage engine that redirects to CloudFront for all requests outside of AWS. """
+  def __init__(self, context, cloudfront_distribution_domain, cloudfront_key_id,
+               cloudfront_privatekey_filename, storage_path, s3_bucket, *args, **kwargs):
+    super(CloudFrontedS3Storage, self).__init__(context, storage_path, s3_bucket, *args, **kwargs)
+
+    self.cloudfront_distribution_domain = cloudfront_distribution_domain
+    self.cloudfront_key_id = cloudfront_key_id
+    self.cloudfront_privatekey = self._load_private_key(cloudfront_privatekey_filename)
+
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
+    logger.debug('Got direct download request for path "%s" with IP "%s"', path, request_ip)
+
+    # Default to None so the final debug log below is well-defined even when no
+    # request IP was provided.
+    resolved_ip_info = None
+    if request_ip is not None:
+      # Lookup the IP address in our resolution table and determine whether it is under AWS. If it is,
+      # then return an S3 signed URL, since we are in-network.
+      resolved_ip_info = self._context.ip_resolver.resolve_ip(request_ip)
+      logger.debug('Resolved IP information for IP %s: %s', request_ip, resolved_ip_info)
+      if resolved_ip_info and resolved_ip_info.provider == 'aws':
+        return super(CloudFrontedS3Storage, self).get_direct_download_url(path, request_ip, expires_in, requires_cors,
+                                                                          head)
+
+    url = 'https://%s/%s' % (self.cloudfront_distribution_domain, path)
+    expire_date = datetime.now() + timedelta(seconds=expires_in)
+    signer = self._get_cloudfront_signer()
+    signed_url = signer.generate_presigned_url(url, date_less_than=expire_date)
+    logger.debug('Returning CloudFront URL for path "%s" with IP "%s": %s', path, resolved_ip_info, signed_url)
+    return signed_url
+
+  @lru_cache(maxsize=1)
+  def _get_cloudfront_signer(self):
+    return CloudFrontSigner(self.cloudfront_key_id, self._get_rsa_signer())
+
+  @lru_cache(maxsize=1)
+  def _get_rsa_signer(self):
+    private_key = self.cloudfront_privatekey
+    def handler(message):
+      signer = private_key.signer(padding.PKCS1v15(), hashes.SHA1())
+      signer.update(message)
+      return signer.finalize()
+
+    return handler
+
+  @lru_cache(maxsize=1)
+  def _load_private_key(self, cloudfront_privatekey_filename):
+    """ Returns the private key, loaded from the config provider, used to sign direct
+        download URLs to CloudFront.
+    """
+    with self._context.config_provider.get_volume_file(cloudfront_privatekey_filename) as key_file:
+      return serialization.load_pem_private_key(
+        key_file.read(),
+        password=None,
+        backend=default_backend(),
+      )
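Outside the class above, the signing flow reduces to a few lines of botocore plus cryptography. A self-contained sketch follows; the key path, key ID, and domain are placeholders, and it uses the newer key.sign() API from the cryptography library rather than the signer() interface used in the diff:

    from datetime import datetime, timedelta

    from botocore.signers import CloudFrontSigner
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes, serialization
    from cryptography.hazmat.primitives.asymmetric import padding

    def make_rsa_signer(pem_path):
      with open(pem_path, 'rb') as f:
        key = serialization.load_pem_private_key(f.read(), password=None, backend=default_backend())
      def rsa_signer(message):
        # CloudFront canned policies are signed with PKCS#1 v1.5 / SHA-1.
        return key.sign(message, padding.PKCS1v15(), hashes.SHA1())
      return rsa_signer

    signer = CloudFrontSigner('APKAEXAMPLEKEYID', make_rsa_signer('cloudfront-signing-key.pem'))
    signed_url = signer.generate_presigned_url(
      'https://d111111abcdef8.cloudfront.net/some/cool/path',
      date_less_than=datetime.now() + timedelta(seconds=60))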


@@ -56,9 +56,9 @@ class DistributedStorage(StoragePaths):
   cancel_chunked_upload = _location_aware(BaseStorageV2.cancel_chunked_upload)

-  def get_direct_download_url(self, locations, path, expires_in=600, requires_cors=False,
+  def get_direct_download_url(self, locations, path, request_ip=None, expires_in=600, requires_cors=False,
                               head=False):
-    download_url = self._get_direct_download_url(locations, path, expires_in, requires_cors, head)
+    download_url = self._get_direct_download_url(locations, path, request_ip, expires_in, requires_cors, head)
     if download_url is None:
       return None


@@ -15,7 +15,7 @@ class FakeStorage(BaseStorageV2):
   def _init_path(self, path=None, create=False):
     return path

-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     try:
       if self.get_content('supports_direct_download') == 'true':
         return 'http://somefakeurl?goes=here'


@@ -147,7 +147,7 @@ class SwiftStorage(BaseStorage):
       logger.exception('Could not head object at path %s: %s', path, ex)
       return None

-  def get_direct_download_url(self, object_path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, object_path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     if requires_cors:
       return None


@@ -0,0 +1,54 @@
+import os
+
+import pytest
+
+from contextlib import contextmanager
+from mock import patch
+from moto import mock_s3
+import boto
+
+from app import config_provider
+from storage import CloudFrontedS3Storage, StorageContext
+from util.ipresolver import IPResolver
+from util.ipresolver.test.test_ipresolver import test_aws_ip, aws_ip_range_data
+from test.fixtures import *
+
+_TEST_CONTENT = os.urandom(1024)
+_TEST_BUCKET = 'some_bucket'
+_TEST_USER = 'someuser'
+_TEST_PASSWORD = 'somepassword'
+_TEST_PATH = 'some/cool/path'
+
+
+@pytest.fixture(params=[True, False])
+def ipranges_populated(request):
+  return request.param
+
+
+@mock_s3
+def test_direct_download(test_aws_ip, aws_ip_range_data, ipranges_populated, app):
+  ipresolver = IPResolver(app)
+  if ipranges_populated:
+    empty_range_data = {
+      'syncToken': 123456789,
+      'prefixes': [],
+    }
+
+    with patch.object(ipresolver, '_get_aws_ip_ranges',
+                      lambda: aws_ip_range_data if ipranges_populated else empty_range_data):
+      context = StorageContext('nyc', None, None, config_provider, ipresolver)
+
+      # Create a test bucket and put some test content.
+      boto.connect_s3().create_bucket(_TEST_BUCKET)
+
+      engine = CloudFrontedS3Storage(context, 'cloudfrontdomain', 'keyid', 'test/data/test.pem', 'some/path',
+                                     _TEST_BUCKET, _TEST_USER, _TEST_PASSWORD)
+      engine.put_content(_TEST_PATH, _TEST_CONTENT)
+      assert engine.exists(_TEST_PATH)
+
+      # Request a direct download URL for a request from a known AWS IP, and ensure we are returned an S3 URL.
+      assert 's3.amazonaws.com' in engine.get_direct_download_url(_TEST_PATH, test_aws_ip)
+
+      if ipranges_populated:
+        # Request a direct download URL for a request from a non-AWS IP, and ensure we are returned a
+        # CloudFront URL.
+        assert 'cloudfrontdomain' in engine.get_direct_download_url(_TEST_PATH, '1.2.3.4')
+      else:
+        # Request a direct download URL for a request from a non-AWS IP, but since IP Ranges isn't populated,
+        # we still get back an S3 URL.
+        assert 's3.amazonaws.com' in engine.get_direct_download_url(_TEST_PATH, '1.2.3.4')


@@ -9,7 +9,7 @@ from storage import StorageContext
 from storage.swift import SwiftStorage

 base_args = {
-  'context': StorageContext('nyc', None, None),
+  'context': StorageContext('nyc', None, None, None, None),
   'swift_container': 'container-name',
   'storage_path': '/basepath',
   'auth_url': 'https://auth.com',
@@ -191,7 +191,7 @@ def test_cancel_chunked_upload():
 def test_empty_chunks_queued_for_deletion():
   chunk_cleanup_queue = FakeQueue()
   args = dict(base_args)
-  args['context'] = StorageContext('nyc', None, chunk_cleanup_queue)
+  args['context'] = StorageContext('nyc', None, chunk_cleanup_queue, None, None)
   swift = FakeSwiftStorage(**args)
   uuid, metadata = swift.initiate_chunked_upload()


@@ -13,7 +13,7 @@ _TEST_BUCKET = 'some_bucket'
 _TEST_USER = 'someuser'
 _TEST_PASSWORD = 'somepassword'
 _TEST_PATH = 'some/cool/path'
-_TEST_CONTEXT = StorageContext('nyc', None, None)
+_TEST_CONTEXT = StorageContext('nyc', None, None, None, None)

 class TestCloudStorage(unittest.TestCase):
   def setUp(self):


@@ -5,7 +5,7 @@ from urlparse import urlparse

 from flask import request

-from app import analytics, userevents
+from app import analytics, userevents, ip_resolver
 from data import model
 from auth.registry_jwt_auth import get_granted_entity
 from auth.auth_context import (get_authenticated_user, get_validated_token,
@@ -85,6 +85,11 @@ def track_and_log(event_name, repo_obj, analytics_name=None, analytics_sample=1,
     analytics.track(analytics_id, analytics_name, extra_params)

+    # Add the resolved information to the metadata.
+    resolved_ip = ip_resolver.resolve_ip(request.remote_addr)
+    if resolved_ip is not None:
+      metadata['resolved_ip'] = resolved_ip._asdict()
+
   # Log the action to the database.
   logger.debug('Logging the %s to logs system', event_name)
   model.log.log_action(event_name, namespace_name, performer=authenticated_user,


@@ -7,7 +7,7 @@ from util.config.provider.baseprovider import BaseProvider
 from util.license import (EntitlementValidationResult, Entitlement, Expiration, ExpirationType,
                           EntitlementRequirement)

-REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg']
+REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg', 'test/data/test.pem']

 class TestLicense(object):
   def validate_entitlement_requirement(self, entitlement_req, check_time):


@@ -1,4 +1,4 @@
-from app import app
+from app import app, ip_resolver, config_provider
 from storage import get_storage_driver
 from util.config.validators import BaseValidator, ConfigValidationException
@@ -36,7 +36,8 @@ def _get_storage_providers(config):
   try:
     for name, parameters in storage_config.items():
-      drivers[name] = (parameters[0], get_storage_driver(None, None, None, parameters))
+      driver = get_storage_driver(None, None, None, config_provider, ip_resolver, parameters)
+      drivers[name] = (parameters[0], driver)
   except TypeError:
     raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)

Binary file not shown.

util/ipresolver/__init__.py

@@ -0,0 +1,117 @@
+import logging
+import json
+
+import requests
+
+from cachetools import ttl_cache, lru_cache
+from collections import namedtuple, defaultdict
+from netaddr import IPNetwork, IPAddress, IPSet, AddrFormatError
+
+import geoip2.database
+import geoip2.errors
+
+ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token'])
+
+logger = logging.getLogger(__name__)
+
+_DATA_FILES = {'aws-ip-ranges.json': 'https://ip-ranges.amazonaws.com/ip-ranges.json'}
+
+
+def update_resolver_datafiles():
+  """ Performs an update of the data file(s) used by the IP Resolver. """
+  for filename, url in _DATA_FILES.iteritems():
+    logger.debug('Updating IP resolver data file "%s" from URL "%s"', filename, url)
+    with open('util/ipresolver/%s' % filename, 'w') as f:
+      response = requests.get(url)
+      logger.debug('Got %s response for URL %s', response.status_code, url)
+      if response.status_code / 100 != 2:
+        raise Exception('Got non-2XX status code for URL %s: %s' % (url, response.status_code))
+
+      f.write(response.text)
+      logger.debug('Successfully wrote %s', filename)
+
+
+class IPResolver(object):
+  def __init__(self, app):
+    self.app = app
+    self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')
+
+  def resolve_ip(self, ip_address):
+    """ Attempts to return resolved information about the specified IP Address. If such an attempt
+        fails, returns None.
+    """
+    location_function = self._get_location_function()
+    if not ip_address or not location_function:
+      return None
+
+    return location_function(ip_address)
+
+  def _get_aws_ip_ranges(self):
+    try:
+      with open('util/ipresolver/aws-ip-ranges.json', 'r') as f:
+        return json.loads(f.read())
+    except (IOError, ValueError, TypeError):
+      logger.exception('Could not load AWS IP Ranges')
+      return None
+
+  @ttl_cache(maxsize=1, ttl=600)
+  def _get_location_function(self):
+    aws_ip_range_json = self._get_aws_ip_ranges()
+    if aws_ip_range_json is None:
+      return None
+
+    sync_token = aws_ip_range_json['syncToken']
+    all_amazon, regions, services = IPResolver._parse_amazon_ranges(aws_ip_range_json)
+    return IPResolver._build_location_function(sync_token, all_amazon, regions, services, self.geoip_db)
+
+  @staticmethod
+  def _build_location_function(sync_token, all_amazon, regions, services, country_db):
+    @lru_cache(maxsize=4096)
+    def _get_location(ip_address):
+      try:
+        parsed_ip = IPAddress(ip_address)
+      except AddrFormatError:
+        return ResolvedLocation('invalid_ip', None, None, sync_token)
+
+      if parsed_ip not in all_amazon:
+        # Try geoip classification
+        try:
+          found = country_db.country(parsed_ip)
+          return ResolvedLocation(
+            'internet',
+            found.continent.code,
+            found.country.iso_code,
+            sync_token,
+          )
+        except geoip2.errors.AddressNotFoundError:
+          return ResolvedLocation('internet', None, None, sync_token)

+      region = None
+
+      for region_name, region_set in regions.items():
+        if parsed_ip in region_set:
+          region = region_name
+          break
+
+      return ResolvedLocation('aws', region, None, sync_token)
+    return _get_location
+
+  @staticmethod
+  def _parse_amazon_ranges(ranges):
+    all_amazon = IPSet()
+    regions = defaultdict(IPSet)
+    services = defaultdict(IPSet)
+
+    for service_description in ranges['prefixes']:
+      cidr = IPNetwork(service_description['ip_prefix'])
+      service = service_description['service']
+      region = service_description['region']
+
+      all_amazon.add(cidr)
+      regions[region].add(cidr)
+      services[service].add(cidr)
+
+    return all_amazon, regions, services
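A short usage sketch for the resolver above; the address is a placeholder, and the module expects aws-ip-ranges.json and the GeoLite2 database on disk as wired up in the Dockerfile earlier in this diff:

    from app import app
    from util.ipresolver import IPResolver

    resolver = IPResolver(app)
    location = resolver.resolve_ip('52.95.110.1')  # placeholder address
    if location is not None and location.provider == 'aws':
      # In-network request: the caller can hand back a plain S3 signed URL.
      print('AWS region: %s' % location.region)
    # For non-AWS addresses, provider is 'internet' and region/service carry the
    # GeoIP continent code and country ISO code. resolve_ip returns None when the
    # ranges file has not been fetched yet.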


@@ -0,0 +1,40 @@
+import pytest
+
+from mock import patch
+from util.ipresolver import IPResolver, ResolvedLocation
+from test.fixtures import *
+
+
+@pytest.fixture()
+def test_aws_ip():
+  return '10.0.0.1'
+
+
+@pytest.fixture()
+def aws_ip_range_data():
+  fake_range_doc = {
+    'syncToken': 123456789,
+    'prefixes': [
+      {
+        'ip_prefix': '10.0.0.0/8',
+        'region': 'GLOBAL',
+        'service': 'AMAZON',
+      },
+    ],
+  }
+  return fake_range_doc
+
+
+def test_unstarted(app, test_aws_ip):
+  ipresolver = IPResolver(app)
+  assert ipresolver.resolve_ip(test_aws_ip) is None
+
+
+def test_resolved(aws_ip_range_data, test_aws_ip, app):
+  ipresolver = IPResolver(app)
+
+  def get_data():
+    return aws_ip_range_data
+
+  with patch.object(ipresolver, '_get_aws_ip_ranges', get_data):
+    assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
+    assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
+    assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
+    assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)


@@ -0,0 +1,45 @@
+import logging
+import time
+
+from app import app
+from util.ipresolver import update_resolver_datafiles
+from workers.worker import Worker
+
+logger = logging.getLogger(__name__)
+
+
+class IPResolverUpdateWorker(Worker):
+  def __init__(self):
+    super(IPResolverUpdateWorker, self).__init__()
+
+    # Update now.
+    try:
+      self._update_resolver_datafiles()
+    except Exception:
+      logger.exception('Initial update of range data files failed')
+
+    self.add_operation(self._update_resolver_datafiles,
+                       app.config.get('IP_RESOLVER_DATAFILE_REFRESH', 60 * 60 * 2) * 60)
+
+  def _update_resolver_datafiles(self):
+    logger.debug('Starting refresh of IP resolver data files')
+    update_resolver_datafiles()
+    logger.debug('Finished refresh of IP resolver data files')
+
+
+if __name__ == "__main__":
+  # Only enable if CloudFronted storage is used.
+  requires_resolution = False
+  for storage_type, _ in app.config.get('DISTRIBUTED_STORAGE_CONFIG', {}).values():
+    if storage_type == 'CloudFrontedS3Storage':
+      requires_resolution = True
+      break
+
+  if not requires_resolution:
+    logger.debug('Cloud fronted storage not used; skipping')
+    while True:
+      time.sleep(10000)
+
+  worker = IPResolverUpdateWorker()
+  worker.start()