Add CloudFrontedS3Storage, which redirects to CloudFront for non-AWS IPs

Joseph Schorr 2017-09-26 16:08:50 -04:00
parent 2d522764f7
commit 010dda2c52
14 changed files with 175 additions and 69 deletions
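For context, the new engine plugs into the existing DISTRIBUTED_STORAGE_CONFIG mechanism: each location maps to a pair of driver name and parameter dict, which get_storage_driver (below) passes to the driver's constructor. A minimal sketch of an entry enabling the new driver — the cloudfront_*, storage_path, and s3_bucket keys come from the CloudFrontedS3Storage constructor in this commit, while the S3 credential key names are assumed for illustration and may differ in the actual S3Storage driver:

# Hypothetical config entry; only the cloudfront_* parameter names are defined by this commit.
DISTRIBUTED_STORAGE_CONFIG = {
  'default': ['CloudFrontedS3Storage', {
    'cloudfront_distribution_domain': 'dexample123456.cloudfront.net',
    'cloudfront_key_id': 'APKAEXAMPLEKEYID',
    'cloudfront_privatekey_filename': 'cloudfront-signing-key.pem',
    'storage_path': '/registry',
    's3_bucket': 'my-registry-bucket',
    's3_access_key': 'ACCESSKEY',   # assumed key name
    's3_secret_key': 'SECRETKEY',   # assumed key name
  }],
}
DISTRIBUTED_STORAGE_PREFERENCE = ['default']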

app.py

@@ -35,6 +35,7 @@ from oauth.loginmanager import OAuthLoginManager
 from storage import Storage
 from util.log import filter_logs
 from util import get_app_url
+from util.ipresolver import IPResolver
 from util.saas.analytics import Analytics
 from util.saas.useranalytics import UserAnalytics
 from util.saas.exceptionlog import Sentry
@@ -195,7 +196,8 @@ prometheus = PrometheusPlugin(app)
 metric_queue = MetricQueue(prometheus)
 chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue)
 instance_keys = InstanceKeys(app)
-storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys)
+ip_resolver = IPResolver(app)
+storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver)
 userfiles = Userfiles(app, storage)
 log_archive = LogArchive(app, storage)
 analytics = Analytics(app)


@@ -23,6 +23,7 @@ bencode
 bintrees
 bitmath
 boto
+boto3
 cachetools==1.1.6
 cryptography
 flask


@@ -23,6 +23,7 @@ bintrees==2.0.6
 bitmath==1.3.1.2
 blinker==1.4
 boto==2.46.1
+boto3==1.4.7
 cachetools==1.1.6
 certifi==2017.4.17
 cffi==1.10.0


@@ -1,5 +1,5 @@
 from storage.local import LocalStorage
-from storage.cloud import S3Storage, GoogleCloudStorage, RadosGWStorage
+from storage.cloud import S3Storage, GoogleCloudStorage, RadosGWStorage, CloudFrontedS3Storage
 from storage.fakestorage import FakeStorage
 from storage.distributedstorage import DistributedStorage
 from storage.swift import SwiftStorage
@@ -11,39 +11,44 @@ STORAGE_DRIVER_CLASSES = {
   'GoogleCloudStorage': GoogleCloudStorage,
   'RadosGWStorage': RadosGWStorage,
   'SwiftStorage': SwiftStorage,
+  'CloudFrontedS3Storage': CloudFrontedS3Storage,
 }

-def get_storage_driver(location, metric_queue, chunk_cleanup_queue, storage_params):
+def get_storage_driver(location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver, storage_params):
   """ Returns a storage driver class for the given storage configuration
       (a pair of string name and a dict of parameters). """
   driver = storage_params[0]
   parameters = storage_params[1]
   driver_class = STORAGE_DRIVER_CLASSES.get(driver, FakeStorage)
-  context = StorageContext(location, metric_queue, chunk_cleanup_queue)
+  context = StorageContext(location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver)
   return driver_class(context, **parameters)

 class StorageContext(object):
-  def __init__(self, location, metric_queue, chunk_cleanup_queue):
+  def __init__(self, location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver):
     self.location = location
     self.metric_queue = metric_queue
     self.chunk_cleanup_queue = chunk_cleanup_queue
+    self.config_provider = config_provider
+    self.ip_resolver = ip_resolver

 class Storage(object):
-  def __init__(self, app=None, metric_queue=None, chunk_cleanup_queue=None, instance_keys=None):
+  def __init__(self, app=None, metric_queue=None, chunk_cleanup_queue=None, instance_keys=None,
+               config_provider=None, ip_resolver=None):
     self.app = app
     if app is not None:
-      self.state = self.init_app(app, metric_queue, chunk_cleanup_queue, instance_keys)
+      self.state = self.init_app(app, metric_queue, chunk_cleanup_queue, instance_keys,
+                                 config_provider, ip_resolver)
     else:
       self.state = None

-  def init_app(self, app, metric_queue, chunk_cleanup_queue, instance_keys):
+  def init_app(self, app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver):
     storages = {}
     for location, storage_params in app.config.get('DISTRIBUTED_STORAGE_CONFIG').items():
       storages[location] = get_storage_driver(location, metric_queue, chunk_cleanup_queue,
-                                              storage_params)
+                                              config_provider, ip_resolver, storage_params)

     preference = app.config.get('DISTRIBUTED_STORAGE_PREFERENCE', None)
     if not preference:


@@ -49,7 +49,7 @@ class BaseStorage(StoragePaths):
     if not self.exists('_verify'):
       raise Exception('Could not find verification file')

-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     return None

   def get_direct_upload_url(self, path, mime_type, requires_cors=True):


@@ -3,8 +3,17 @@ import os
 import logging
 import copy

+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import padding
+
+from cachetools import lru_cache
 from itertools import chain
+from datetime import datetime, timedelta
+
+from botocore.signers import CloudFrontSigner
 from boto.exception import S3ResponseError
 import boto.s3.connection
 import boto.s3.multipart
@@ -590,3 +599,58 @@ class RadosGWStorage(_CloudStorage):
     # See https://github.com/ceph/ceph/pull/5139
     chunk_list = self._chunk_list_from_metadata(storage_metadata)
     self._client_side_chunk_join(final_path, chunk_list)
+
+
+class CloudFrontedS3Storage(S3Storage):
+  """ An S3Storage engine that redirects to CloudFront for all requests outside of AWS. """
+  def __init__(self, context, cloudfront_distribution_domain, cloudfront_key_id,
+               cloudfront_privatekey_filename, storage_path, s3_bucket, *args, **kwargs):
+    super(CloudFrontedS3Storage, self).__init__(context, storage_path, s3_bucket, *args, **kwargs)
+
+    self.cloudfront_distribution_domain = cloudfront_distribution_domain
+    self.cloudfront_key_id = cloudfront_key_id
+    self.cloudfront_privatekey = self._load_private_key(cloudfront_privatekey_filename)
+
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
+    logger.debug('Got direct download request for path "%s" with IP "%s"', path, request_ip)
+    if request_ip is not None:
+      # Lookup the IP address in our resolution table and determine whether it is under AWS.
+      # If it is *not*, then return a CloudFront signed URL.
+      resolved_ip_info = self._context.ip_resolver.resolve_ip(request_ip)
+      logger.debug('Resolved IP information for IP %s: %s', request_ip, resolved_ip_info)
+      if resolved_ip_info and resolved_ip_info.provider != 'aws':
+        url = 'https://%s/%s' % (self.cloudfront_distribution_domain, path)
+        expire_date = datetime.now() + timedelta(seconds=expires_in)
+        signer = self._get_cloudfront_signer()
+        signed_url = signer.generate_presigned_url(url, date_less_than=expire_date)
+        logger.debug('Returning CloudFront URL for path "%s" with IP "%s": %s', path,
+                     resolved_ip_info, signed_url)
+        return signed_url
+
+    return super(CloudFrontedS3Storage, self).get_direct_download_url(path, request_ip, expires_in,
+                                                                      requires_cors, head)
+
+  @lru_cache(maxsize=1)
+  def _get_cloudfront_signer(self):
+    return CloudFrontSigner(self.cloudfront_key_id, self._get_rsa_signer())
+
+  @lru_cache(maxsize=1)
+  def _get_rsa_signer(self):
+    private_key = self.cloudfront_privatekey
+    def handler(message):
+      signer = private_key.signer(padding.PKCS1v15(), hashes.SHA1())
+      signer.update(message)
+      return signer.finalize()
+
+    return handler
+
+  @lru_cache(maxsize=1)
+  def _load_private_key(self, cloudfront_privatekey_filename):
+    """ Returns the private key, loaded from the config provider, used to sign direct
+        download URLs to CloudFront.
+    """
+    with self._context.config_provider.get_volume_file(cloudfront_privatekey_filename) as key_file:
+      return serialization.load_pem_private_key(
+        key_file.read(),
+        password=None,
+        backend=default_backend()
+      )
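The URL signing in CloudFrontedS3Storage is a thin wrapper around botocore's CloudFrontSigner plus an RSA signer built with the cryptography package. A standalone sketch of the same mechanism, outside the storage classes — the key ID, key path, and domain are placeholders, and the newer private_key.sign() call stands in for the signer()/update()/finalize() sequence used above:

from datetime import datetime, timedelta

from botocore.signers import CloudFrontSigner
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding


def make_cloudfront_signer(key_id, pem_path):
  # Load the CloudFront key pair's private key from a local PEM file.
  with open(pem_path, 'rb') as key_file:
    private_key = serialization.load_pem_private_key(key_file.read(), password=None,
                                                     backend=default_backend())

  def rsa_signer(message):
    # CloudFront signed URLs require an RSA SHA-1 signature with PKCS#1 v1.5 padding.
    return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())

  return CloudFrontSigner(key_id, rsa_signer)


signer = make_cloudfront_signer('APKAEXAMPLEKEYID', 'cloudfront-signing-key.pem')
expires = datetime.now() + timedelta(seconds=60)
print(signer.generate_presigned_url('https://dexample123456.cloudfront.net/some/cool/path',
                                    date_less_than=expires))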


@@ -0,0 +1,56 @@
+import os
+
+import pytest
+
+from httmock import urlmatch, HTTMock
+from moto import mock_s3
+import boto
+
+from app import config_provider
+from storage import CloudFrontedS3Storage, StorageContext
+from util.ipresolver import IPResolver
+from util.ipresolver.test.test_ipresolver import http_client, test_aws_ip, aws_ip_range_handler
+from test.fixtures import *
+
+_TEST_CONTENT = os.urandom(1024)
+_TEST_BUCKET = 'some_bucket'
+_TEST_USER = 'someuser'
+_TEST_PASSWORD = 'somepassword'
+_TEST_PATH = 'some/cool/path'
+
+@pytest.fixture(params=[True, False])
+def ipranges_populated(request):
+  return request.param
+
+@pytest.fixture()
+def ipresolver(http_client, aws_ip_range_handler, ipranges_populated, app):
+  with HTTMock(aws_ip_range_handler):
+    ipresolver = IPResolver(app, client=http_client)
+    if ipranges_populated:
+      assert ipresolver._update_aws_ip_range()
+
+    return ipresolver
+
+@pytest.fixture()
+def storage_context(ipresolver, app):
+  return StorageContext('nyc', None, None, config_provider, ipresolver)
+
+@mock_s3
+def test_direct_download(storage_context, test_aws_ip, ipranges_populated, app):
+  # Create a test bucket and put some test content.
+  boto.connect_s3().create_bucket(_TEST_BUCKET)
+
+  engine = CloudFrontedS3Storage(storage_context, 'cloudfrontdomain', 'keyid', 'test/data/test.pem',
+                                 'some/path', _TEST_BUCKET, _TEST_USER, _TEST_PASSWORD)
+  engine.put_content(_TEST_PATH, _TEST_CONTENT)
+  assert engine.exists(_TEST_PATH)
+
+  # Request a direct download URL for a request from a known AWS IP, and ensure we are returned an S3 URL.
+  assert 's3.amazonaws.com' in engine.get_direct_download_url(_TEST_PATH, test_aws_ip)
+
+  if ipranges_populated:
+    # Request a direct download URL for a request from a non-AWS IP, and ensure we are returned a
+    # CloudFront URL.
+    assert 'cloudfrontdomain' in engine.get_direct_download_url(_TEST_PATH, '1.2.3.4')
+  else:
+    # Request a direct download URL for a request from a non-AWS IP, but since the IP ranges aren't
+    # populated, we still get back an S3 URL.
+    assert 's3.amazonaws.com' in engine.get_direct_download_url(_TEST_PATH, '1.2.3.4')


@@ -9,7 +9,7 @@ from storage import StorageContext
 from storage.swift import SwiftStorage

 base_args = {
-  'context': StorageContext('nyc', None, None),
+  'context': StorageContext('nyc', None, None, None, None),
   'swift_container': 'container-name',
   'storage_path': '/basepath',
   'auth_url': 'https://auth.com',
@@ -191,7 +191,7 @@ def test_cancel_chunked_upload():
 def test_empty_chunks_queued_for_deletion():
   chunk_cleanup_queue = FakeQueue()

   args = dict(base_args)
-  args['context'] = StorageContext('nyc', None, chunk_cleanup_queue)
+  args['context'] = StorageContext('nyc', None, chunk_cleanup_queue, None, None)

   swift = FakeSwiftStorage(**args)
   uuid, metadata = swift.initiate_chunked_upload()


@@ -13,7 +13,7 @@ _TEST_BUCKET = 'some_bucket'
 _TEST_USER = 'someuser'
 _TEST_PASSWORD = 'somepassword'
 _TEST_PATH = 'some/cool/path'
-_TEST_CONTEXT = StorageContext('nyc', None, None)
+_TEST_CONTEXT = StorageContext('nyc', None, None, None, None)

 class TestCloudStorage(unittest.TestCase):
   def setUp(self):


@@ -7,7 +7,7 @@ from util.config.provider.baseprovider import BaseProvider
 from util.license import (EntitlementValidationResult, Entitlement, Expiration, ExpirationType,
                           EntitlementRequirement)

-REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg']
+REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg', 'test/data/test.pem']

 class TestLicense(object):
   def validate_entitlement_requirement(self, entitlement_req, check_time):


@@ -1,4 +1,4 @@
-from app import app
+from app import app, ip_resolver, config_provider
 from storage import get_storage_driver
 from util.config.validators import BaseValidator, ConfigValidationException
@@ -36,7 +36,8 @@ def _get_storage_providers(config):
   try:
     for name, parameters in storage_config.items():
-      drivers[name] = (parameters[0], get_storage_driver(None, None, None, parameters))
+      driver = get_storage_driver(None, None, None, config_provider, ip_resolver, parameters)
+      drivers[name] = (parameters[0], driver)
   except TypeError:
     raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)


@@ -1,67 +1,50 @@
 import logging
-import time
+import json

-from cachetools import lru_cache
+from cachetools import ttl_cache, lru_cache
 from collections import namedtuple, defaultdict
 from netaddr import IPNetwork, IPAddress, IPSet, AddrFormatError
-from threading import Thread

 import geoip2.database
 import geoip2.errors

-_AWS_IP_RANGES_URL = 'https://ip-ranges.amazonaws.com/ip-ranges.json'
-_UPDATE_TIME = 60 * 60 * 24
-_RETRY_TIME = 60 * 60 * 5
-
 ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token'])

 logger = logging.getLogger(__name__)

-class IPResolver(Thread):
-  def __init__(self, app, client=None, *args, **kwargs):
-    super(IPResolver, self).__init__(*args, **kwargs)
-    self.daemon = True
-
+class IPResolver(object):
+  def __init__(self, app, *args, **kwargs):
     self.app = app
-    self.client = client or app.config['HTTPCLIENT']
-    self.location_function = None
-    self.sync_token = None
     self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')

   def resolve_ip(self, ip_address):
     """ Attempts to return resolved information about the specified IP Address. If such an attempt fails,
        returns None.
     """
-    location_function = self.location_function
+    location_function = self._get_location_function()
     if not ip_address or not location_function:
       return None

     return location_function(ip_address)

-  def _update_aws_ip_range(self):
-    logger.debug('Starting download of AWS IP Range table from %s', _AWS_IP_RANGES_URL)
-    try:
-      response = self.client.get(_AWS_IP_RANGES_URL)
-      if response.status_code / 100 != 2:
-        logger.error('Non-200 response (%s) for AWS IP Range table request', response.status_code)
-        return False
-    except:
-      logger.exception('Could not download AWS IP range table')
-      return False
-
-    # Check if the sync token is the same. If so, no updates are necessary.
-    if self.sync_token and response.json()['syncToken'] == self.sync_token:
-      logger.debug('No updates necessary')
-      return True
-
-    # Otherwise, update the range lookup function.
-    all_amazon, regions, services = IPResolver._parse_amazon_ranges(response.json())
-    self.sync_token = response.json()['syncToken']
-    self.location_function = IPResolver._build_location_function(self.sync_token, all_amazon, regions, services, self.geoip_db)
-    logger.debug('Successfully updated AWS IP range table with sync token: %s', self.sync_token)
-    return True
+  @ttl_cache(maxsize=1, ttl=600)
+  def _get_location_function(self):
+    try:
+      with open('util/ipresolver/ip-ranges.json', 'r') as f:
+        ip_range_json = json.loads(f.read())
+    except IOError:
+      logger.exception('Could not load IP Ranges')
+      return None
+    except ValueError:
+      logger.exception('Could not load IP Ranges')
+      return None
+    except TypeError:
+      logger.exception('Could not load IP Ranges')
+      return None
+
+    sync_token = ip_range_json['syncToken']
+    all_amazon, regions, services = IPResolver._parse_amazon_ranges(ip_range_json)
+    return IPResolver._build_location_function(sync_token, all_amazon, regions, services, self.geoip_db)

   @staticmethod
   def _build_location_function(sync_token, all_amazon, regions, country, country_db):
@@ -111,14 +94,3 @@ class IPResolver(Thread):
       services[service].add(cidr)

     return all_amazon, regions, services
-
-  def run(self):
-    while True:
-      logger.debug('Updating AWS IP database')
-      if not self._update_aws_ip_range():
-        logger.debug('Failed; sleeping for %s seconds', _RETRY_TIME)
-        time.sleep(_RETRY_TIME)
-        continue
-
-      logger.debug('Success; sleeping for %s seconds', _UPDATE_TIME)
-      time.sleep(_UPDATE_TIME)
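The rewritten _get_location_function reads a local copy of Amazon's published ip-ranges.json instead of downloading it from a background thread. For reference, an abbreviated illustration of that file's structure; the values here are made up to mirror the test fixtures (sync token 123456789, a GLOBAL prefix covering 10.0.0.0/8), not real AWS ranges:

{
  "syncToken": 123456789,
  "createDate": "2017-09-26-20-00-00",
  "prefixes": [
    {"ip_prefix": "10.0.0.0/8", "region": "GLOBAL", "service": "AMAZON"},
    {"ip_prefix": "54.231.0.0/17", "region": "us-east-1", "service": "S3"}
  ]
}

_parse_amazon_ranges consumes the prefixes list, and the resulting location function reports the syncToken it was built from, which is what the ResolvedLocation assertions in the tests below check.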


@@ -17,6 +17,10 @@ def http_client():
   sess.mount('https://', adapter)
   return sess

+@pytest.fixture()
+def test_aws_ip():
+  return '10.0.0.1'
+
 @pytest.fixture()
 def aws_ip_range_handler():
   @urlmatch(netloc=r'ip-ranges.amazonaws.com')
@@ -35,16 +39,16 @@ def aws_ip_range_handler():
   return handler

-def test_unstarted(app, http_client):
+def test_unstarted(app, test_aws_ip, http_client):
   ipresolver = IPResolver(app, client=http_client)

-  assert ipresolver.resolve_ip('10.0.0.1') is None
+  assert ipresolver.resolve_ip(test_aws_ip) is None

-def test_resolved(aws_ip_range_handler, app, http_client):
+def test_resolved(aws_ip_range_handler, test_aws_ip, app, http_client):
   with HTTMock(aws_ip_range_handler):
     ipresolver = IPResolver(app, client=http_client)
     assert ipresolver._update_aws_ip_range()

-    assert ipresolver.resolve_ip('10.0.0.1') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
+    assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
     assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
     assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
     assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)