Add CloudFrontedS3Storage, which redirects to CloudFront for non-S3 ips

parent 2d522764f7
commit 010dda2c52

14 changed files with 175 additions and 69 deletions
4	app.py

@@ -35,6 +35,7 @@ from oauth.loginmanager import OAuthLoginManager
 from storage import Storage
 from util.log import filter_logs
 from util import get_app_url
+from util.ipresolver import IPResolver
 from util.saas.analytics import Analytics
 from util.saas.useranalytics import UserAnalytics
 from util.saas.exceptionlog import Sentry
@@ -195,7 +196,8 @@ prometheus = PrometheusPlugin(app)
 metric_queue = MetricQueue(prometheus)
 chunk_cleanup_queue = WorkQueue(app.config['CHUNK_CLEANUP_QUEUE_NAME'], tf, metric_queue=metric_queue)
 instance_keys = InstanceKeys(app)
-storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys)
+ip_resolver = IPResolver(app)
+storage = Storage(app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver)
 userfiles = Userfiles(app, storage)
 log_archive = LogArchive(app, storage)
 analytics = Analytics(app)
@@ -23,6 +23,7 @@ bencode
 bintrees
 bitmath
 boto
+boto3
 cachetools==1.1.6
 cryptography
 flask
@@ -23,6 +23,7 @@ bintrees==2.0.6
 bitmath==1.3.1.2
 blinker==1.4
 boto==2.46.1
+boto3==1.4.7
 cachetools==1.1.6
 certifi==2017.4.17
 cffi==1.10.0
@@ -1,5 +1,5 @@
 from storage.local import LocalStorage
-from storage.cloud import S3Storage, GoogleCloudStorage, RadosGWStorage
+from storage.cloud import S3Storage, GoogleCloudStorage, RadosGWStorage, CloudFrontedS3Storage
 from storage.fakestorage import FakeStorage
 from storage.distributedstorage import DistributedStorage
 from storage.swift import SwiftStorage
@@ -11,39 +11,44 @@ STORAGE_DRIVER_CLASSES = {
   'GoogleCloudStorage': GoogleCloudStorage,
   'RadosGWStorage': RadosGWStorage,
   'SwiftStorage': SwiftStorage,
+  'CloudFrontedS3Storage': CloudFrontedS3Storage,
 }


-def get_storage_driver(location, metric_queue, chunk_cleanup_queue, storage_params):
+def get_storage_driver(location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver, storage_params):
   """ Returns a storage driver class for the given storage configuration
       (a pair of string name and a dict of parameters). """
   driver = storage_params[0]
   parameters = storage_params[1]
   driver_class = STORAGE_DRIVER_CLASSES.get(driver, FakeStorage)
-  context = StorageContext(location, metric_queue, chunk_cleanup_queue)
+  context = StorageContext(location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver)
   return driver_class(context, **parameters)


 class StorageContext(object):
-  def __init__(self, location, metric_queue, chunk_cleanup_queue):
+  def __init__(self, location, metric_queue, chunk_cleanup_queue, config_provider, ip_resolver):
     self.location = location
     self.metric_queue = metric_queue
     self.chunk_cleanup_queue = chunk_cleanup_queue
+    self.config_provider = config_provider
+    self.ip_resolver = ip_resolver


 class Storage(object):
-  def __init__(self, app=None, metric_queue=None, chunk_cleanup_queue=None, instance_keys=None):
+  def __init__(self, app=None, metric_queue=None, chunk_cleanup_queue=None, instance_keys=None,
+               config_provider=None, ip_resolver=None):
     self.app = app
     if app is not None:
-      self.state = self.init_app(app, metric_queue, chunk_cleanup_queue, instance_keys)
+      self.state = self.init_app(app, metric_queue, chunk_cleanup_queue, instance_keys,
+                                 config_provider, ip_resolver)
     else:
       self.state = None

-  def init_app(self, app, metric_queue, chunk_cleanup_queue, instance_keys):
+  def init_app(self, app, metric_queue, chunk_cleanup_queue, instance_keys, config_provider, ip_resolver):
     storages = {}
     for location, storage_params in app.config.get('DISTRIBUTED_STORAGE_CONFIG').items():
       storages[location] = get_storage_driver(location, metric_queue, chunk_cleanup_queue,
-                                              storage_params)
+                                              config_provider, ip_resolver, storage_params)

     preference = app.config.get('DISTRIBUTED_STORAGE_PREFERENCE', None)
     if not preference:
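For reference, a DISTRIBUTED_STORAGE_CONFIG entry for the new driver might look like the sketch below. The parameter names come from the CloudFrontedS3Storage.__init__ signature added in this commit; the values and the S3 credential key names are illustrative assumptions, not taken from this change.

# Hypothetical config sketch; values are placeholders, not a tested example.
# Each location maps to a [driver_name, parameters] pair; get_storage_driver()
# looks the driver up in STORAGE_DRIVER_CLASSES and splats the parameters dict
# into the driver's constructor as keyword arguments.
DISTRIBUTED_STORAGE_CONFIG = {
  'default': ['CloudFrontedS3Storage', {
    'cloudfront_distribution_domain': 'd1234example.cloudfront.net',
    'cloudfront_key_id': 'APKAEXAMPLEKEYID',
    'cloudfront_privatekey_filename': 'cloudfront-signing-key.pem',  # read via config_provider
    'storage_path': '/registry',
    's3_bucket': 'my-registry-bucket',
    # The remaining S3 credentials flow through *args/**kwargs to S3Storage;
    # these key names are assumed, not shown in this commit.
    's3_access_key': 'ACCESSKEY',
    's3_secret_key': 'SECRETKEY',
  }],
}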
@@ -49,7 +49,7 @@ class BaseStorage(StoragePaths):
     if not self.exists('_verify'):
       raise Exception('Could not find verification file')

-  def get_direct_download_url(self, path, expires_in=60, requires_cors=False, head=False):
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
     return None

   def get_direct_upload_url(self, path, mime_type, requires_cors=True):
@@ -3,8 +3,17 @@ import os
 import logging
 import copy

+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import padding
+
 from cachetools import lru_cache
+from itertools import chain

+from datetime import datetime, timedelta
+
+from botocore.signers import CloudFrontSigner
 from boto.exception import S3ResponseError
 import boto.s3.connection
 import boto.s3.multipart
@@ -590,3 +599,58 @@ class RadosGWStorage(_CloudStorage):
     # See https://github.com/ceph/ceph/pull/5139
     chunk_list = self._chunk_list_from_metadata(storage_metadata)
     self._client_side_chunk_join(final_path, chunk_list)
+
+
+class CloudFrontedS3Storage(S3Storage):
+  """ An S3Storage engine that redirects to CloudFront for all requests outside of AWS. """
+  def __init__(self, context, cloudfront_distribution_domain, cloudfront_key_id,
+               cloudfront_privatekey_filename, storage_path, s3_bucket, *args, **kwargs):
+    super(CloudFrontedS3Storage, self).__init__(context, storage_path, s3_bucket, *args, **kwargs)
+
+    self.cloudfront_distribution_domain = cloudfront_distribution_domain
+    self.cloudfront_key_id = cloudfront_key_id
+    self.cloudfront_privatekey = self._load_private_key(cloudfront_privatekey_filename)
+
+  def get_direct_download_url(self, path, request_ip=None, expires_in=60, requires_cors=False, head=False):
+    logger.debug('Got direct download request for path "%s" with IP "%s"', path, request_ip)
+    if request_ip is not None:
+      # Lookup the IP address in our resolution table and determine whether it is under AWS.
+      # If it is *not*, then return a CloudFront signed URL.
+      resolved_ip_info = self._context.ip_resolver.resolve_ip(request_ip)
+      logger.debug('Resolved IP information for IP %s: %s', request_ip, resolved_ip_info)
+      if resolved_ip_info and resolved_ip_info.provider != 'aws':
+        url = 'https://%s/%s' % (self.cloudfront_distribution_domain, path)
+        expire_date = datetime.now() + timedelta(seconds=expires_in)
+        signer = self._get_cloudfront_signer()
+        signed_url = signer.generate_presigned_url(url, date_less_than=expire_date)
+        logger.debug('Returning CloudFront URL for path "%s" with IP "%s": %s', path, resolved_ip_info, signed_url)
+        return signed_url
+
+    return super(CloudFrontedS3Storage, self).get_direct_download_url(path, request_ip, expires_in,
+                                                                      requires_cors, head)
+
+  @lru_cache(maxsize=1)
+  def _get_cloudfront_signer(self):
+    return CloudFrontSigner(self.cloudfront_key_id, self._get_rsa_signer())
+
+  @lru_cache(maxsize=1)
+  def _get_rsa_signer(self):
+    private_key = self.cloudfront_privatekey
+    def handler(message):
+      signer = private_key.signer(padding.PKCS1v15(), hashes.SHA1())
+      signer.update(message)
+      return signer.finalize()
+
+    return handler
+
+  @lru_cache(maxsize=1)
+  def _load_private_key(self, cloudfront_privatekey_filename):
+    """ Returns the private key, loaded from the config provider, used to sign direct
+        download URLs to CloudFront.
+    """
+    with self._context.config_provider.get_volume_file(cloudfront_privatekey_filename) as key_file:
+      return serialization.load_pem_private_key(
+        key_file.read(),
+        password=None,
+        backend=default_backend()
+      )
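The class above leans on botocore's CloudFrontSigner: it is constructed with a key-pair ID and an RSA-signing callable, and generate_presigned_url() appends the expiry policy, signature, and key-pair ID to the URL as query parameters. Below is a minimal standalone sketch of the same flow, assuming a hypothetical key file and distribution domain, and using cryptography's one-shot sign() call in place of the streaming signer() API used in the diff.

from datetime import datetime, timedelta

from botocore.signers import CloudFrontSigner
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding

# Hypothetical key path and distribution domain, for illustration only.
with open('cloudfront-signing-key.pem', 'rb') as key_file:
  private_key = serialization.load_pem_private_key(key_file.read(), password=None,
                                                   backend=default_backend())

def rsa_signer(message):
  # CloudFront expects an RSA/SHA-1 signature over the policy document.
  return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())

signer = CloudFrontSigner('APKAEXAMPLEKEYID', rsa_signer)
expires = datetime.now() + timedelta(seconds=60)
print(signer.generate_presigned_url('https://d1234example.cloudfront.net/some/cool/path',
                                    date_less_than=expires))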
56	storage/test/test_cloudfront.py (new file)

@@ -0,0 +1,56 @@
+import os
+import pytest
+
+from httmock import urlmatch, HTTMock
+from moto import mock_s3
+import boto
+
+from app import config_provider
+from storage import CloudFrontedS3Storage, StorageContext
+from util.ipresolver import IPResolver
+from util.ipresolver.test.test_ipresolver import http_client, test_aws_ip, aws_ip_range_handler
+from test.fixtures import *
+
+_TEST_CONTENT = os.urandom(1024)
+_TEST_BUCKET = 'some_bucket'
+_TEST_USER = 'someuser'
+_TEST_PASSWORD = 'somepassword'
+_TEST_PATH = 'some/cool/path'
+
+@pytest.fixture(params=[True, False])
+def ipranges_populated(request):
+  return request.param
+
+@pytest.fixture()
+def ipresolver(http_client, aws_ip_range_handler, ipranges_populated, app):
+  with HTTMock(aws_ip_range_handler):
+    ipresolver = IPResolver(app, client=http_client)
+    if ipranges_populated:
+      assert ipresolver._update_aws_ip_range()
+
+  return ipresolver
+
+@pytest.fixture()
+def storage_context(ipresolver, app):
+  return StorageContext('nyc', None, None, config_provider, ipresolver)
+
+@mock_s3
+def test_direct_download(storage_context, test_aws_ip, ipranges_populated, app):
+  # Create a test bucket and put some test content.
+  boto.connect_s3().create_bucket(_TEST_BUCKET)
+
+  engine = CloudFrontedS3Storage(storage_context, 'cloudfrontdomain', 'keyid', 'test/data/test.pem',
+                                 'some/path', _TEST_BUCKET, _TEST_USER, _TEST_PASSWORD)
+  engine.put_content(_TEST_PATH, _TEST_CONTENT)
+  assert engine.exists(_TEST_PATH)
+
+  # Request a direct download URL for a request from a known AWS IP, and ensure we are returned an S3 URL.
+  assert 's3.amazonaws.com' in engine.get_direct_download_url(_TEST_PATH, test_aws_ip)
+
+  if ipranges_populated:
+    # Request a direct download URL for a request from a non-AWS IP, and ensure we are returned a CloudFront URL.
+    assert 'cloudfrontdomain' in engine.get_direct_download_url(_TEST_PATH, '1.2.3.4')
+  else:
+    # Request a direct download URL for a request from a non-AWS IP, but since IP Ranges isn't populated, we still
+    # get back an S3 URL.
+    assert 's3.amazonaws.com' in engine.get_direct_download_url(_TEST_PATH, '1.2.3.4')
@@ -9,7 +9,7 @@ from storage import StorageContext
 from storage.swift import SwiftStorage

 base_args = {
-  'context': StorageContext('nyc', None, None),
+  'context': StorageContext('nyc', None, None, None, None),
   'swift_container': 'container-name',
   'storage_path': '/basepath',
   'auth_url': 'https://auth.com',
@@ -191,7 +191,7 @@ def test_cancel_chunked_upload():
 def test_empty_chunks_queued_for_deletion():
   chunk_cleanup_queue = FakeQueue()
   args = dict(base_args)
-  args['context'] = StorageContext('nyc', None, chunk_cleanup_queue)
+  args['context'] = StorageContext('nyc', None, chunk_cleanup_queue, None, None)

   swift = FakeSwiftStorage(**args)
   uuid, metadata = swift.initiate_chunked_upload()
@@ -13,7 +13,7 @@ _TEST_BUCKET = 'some_bucket'
 _TEST_USER = 'someuser'
 _TEST_PASSWORD = 'somepassword'
 _TEST_PATH = 'some/cool/path'
-_TEST_CONTEXT = StorageContext('nyc', None, None)
+_TEST_CONTEXT = StorageContext('nyc', None, None, None, None)

 class TestCloudStorage(unittest.TestCase):
   def setUp(self):
@@ -7,7 +7,7 @@ from util.config.provider.baseprovider import BaseProvider
 from util.license import (EntitlementValidationResult, Entitlement, Expiration, ExpirationType,
                           EntitlementRequirement)

-REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg']
+REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg', 'test/data/test.pem']

 class TestLicense(object):
   def validate_entitlement_requirement(self, entitlement_req, check_time):
@@ -1,4 +1,4 @@
-from app import app
+from app import app, ip_resolver, config_provider
 from storage import get_storage_driver
 from util.config.validators import BaseValidator, ConfigValidationException

@@ -36,7 +36,8 @@ def _get_storage_providers(config):

   try:
     for name, parameters in storage_config.items():
-      drivers[name] = (parameters[0], get_storage_driver(None, None, None, parameters))
+      driver = get_storage_driver(None, None, None, config_provider, ip_resolver, parameters)
+      drivers[name] = (parameters[0], driver)
   except TypeError:
     raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)
@@ -1,67 +1,50 @@
 import logging
-import time
+import json

-from cachetools import lru_cache
+from cachetools import ttl_cache, lru_cache
 from collections import namedtuple, defaultdict
 from netaddr import IPNetwork, IPAddress, IPSet, AddrFormatError
-from threading import Thread

 import geoip2.database
 import geoip2.errors

-_AWS_IP_RANGES_URL = 'https://ip-ranges.amazonaws.com/ip-ranges.json'
-_UPDATE_TIME = 60 * 60 * 24
-_RETRY_TIME = 60 * 60 * 5
-
 ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token'])

 logger = logging.getLogger(__name__)

-class IPResolver(Thread):
-  def __init__(self, app, client=None, *args, **kwargs):
-    super(IPResolver, self).__init__(*args, **kwargs)
-    self.daemon = True
-
+class IPResolver(object):
+  def __init__(self, app, *args, **kwargs):
     self.app = app
-    self.client = client or app.config['HTTPCLIENT']
-
-    self.location_function = None
-    self.sync_token = None
-
     self.geoip_db = geoip2.database.Reader('util/ipresolver/GeoLite2-Country.mmdb')

   def resolve_ip(self, ip_address):
     """ Attempts to return resolved information about the specified IP Address. If such an
         attempt fails, returns None.
     """
-    location_function = self.location_function
+    location_function = self._get_location_function()
     if not ip_address or not location_function:
       return None

     return location_function(ip_address)

-  def _update_aws_ip_range(self):
-    logger.debug('Starting download of AWS IP Range table from %s', _AWS_IP_RANGES_URL)
+  @ttl_cache(maxsize=1, ttl=600)
+  def _get_location_function(self):
     try:
-      response = self.client.get(_AWS_IP_RANGES_URL)
-      if response.status_code / 100 != 2:
-        logger.error('Non-200 response (%s) for AWS IP Range table request', response.status_code)
-        return False
-    except:
-      logger.exception('Could not download AWS IP range table')
-      return False
+      with open('util/ipresolver/ip-ranges.json', 'r') as f:
+        ip_range_json = json.loads(f.read())
+    except IOError:
+      logger.exception('Could not load IP Ranges')
+      return None
+    except ValueError:
+      logger.exception('Could not load IP Ranges')
+      return None
+    except TypeError:
+      logger.exception('Could not load IP Ranges')
+      return None

-    # Check if the sync token is the same. If so, no updates are necessary.
-    if self.sync_token and response.json()['syncToken'] == self.sync_token:
-      logger.debug('No updates necessary')
-      return True
-
-    # Otherwise, update the range lookup function.
-    all_amazon, regions, services = IPResolver._parse_amazon_ranges(response.json())
-    self.sync_token = response.json()['syncToken']
-    self.location_function = IPResolver._build_location_function(self.sync_token, all_amazon, regions, services, self.geoip_db)
-    logger.debug('Successfully updated AWS IP range table with sync token: %s', self.sync_token)
-    return True
+    sync_token = ip_range_json['syncToken']
+    all_amazon, regions, services = IPResolver._parse_amazon_ranges(ip_range_json)
+    return IPResolver._build_location_function(sync_token, all_amazon, regions, services, self.geoip_db)

   @staticmethod
   def _build_location_function(sync_token, all_amazon, regions, country, country_db):
@@ -111,14 +94,3 @@ class IPResolver(Thread):
       services[service].add(cidr)

     return all_amazon, regions, services
-
-  def run(self):
-    while True:
-      logger.debug('Updating AWS IP database')
-      if not self._update_aws_ip_range():
-        logger.debug('Failed; sleeping for %s seconds', _RETRY_TIME)
-        time.sleep(_RETRY_TIME)
-        continue
-
-      logger.debug('Success; sleeping for %s seconds', _UPDATE_TIME)
-      time.sleep(_UPDATE_TIME)
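_build_location_function itself is not modified here, so its body does not appear in the diff; conceptually it closes over the parsed AWS ranges and returns a per-address lookup. Below is a rough sketch of that idea with netaddr, under assumed data shapes (the real function also resolves services and falls back to the GeoIP country database for non-AWS addresses).

from collections import namedtuple
from netaddr import IPAddress, IPSet, AddrFormatError

ResolvedLocation = namedtuple('ResolvedLocation', ['provider', 'region', 'service', 'sync_token'])

def build_location_function(sync_token, all_amazon, regions):
  # Assumed shapes: all_amazon is an IPSet of every published AWS CIDR;
  # regions is a dict mapping region name -> IPSet of that region's CIDRs.
  def resolve(ip_address):
    try:
      ip = IPAddress(ip_address)
    except AddrFormatError:
      return None

    if ip in all_amazon:
      region = next((name for name, ipset in regions.items() if ip in ipset), None)
      return ResolvedLocation('aws', region, None, sync_token)

    # The real implementation would consult the GeoIP database here instead.
    return ResolvedLocation('internet', None, None, sync_token)

  return resolve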
0	util/ipresolver/test/__init__.py (new, empty file)
@@ -17,6 +17,10 @@ def http_client():
   sess.mount('https://', adapter)
   return sess

+@pytest.fixture()
+def test_aws_ip():
+  return '10.0.0.1'
+
 @pytest.fixture()
 def aws_ip_range_handler():
   @urlmatch(netloc=r'ip-ranges.amazonaws.com')
@@ -35,16 +39,16 @@ def aws_ip_range_handler():

   return handler

-def test_unstarted(app, http_client):
+def test_unstarted(app, test_aws_ip, http_client):
   ipresolver = IPResolver(app, client=http_client)
-  assert ipresolver.resolve_ip('10.0.0.1') is None
+  assert ipresolver.resolve_ip(test_aws_ip) is None

-def test_resolved(aws_ip_range_handler, app, http_client):
+def test_resolved(aws_ip_range_handler, test_aws_ip, app, http_client):
   with HTTMock(aws_ip_range_handler):
     ipresolver = IPResolver(app, client=http_client)
     assert ipresolver._update_aws_ip_range()

-  assert ipresolver.resolve_ip('10.0.0.1') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
+  assert ipresolver.resolve_ip(test_aws_ip) == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
   assert ipresolver.resolve_ip('10.0.0.2') == ResolvedLocation(provider='aws', region=u'GLOBAL', service=None, sync_token=123456789)
   assert ipresolver.resolve_ip('1.2.3.4') == ResolvedLocation(provider='internet', region=u'NA', service=u'US', sync_token=123456789)
   assert ipresolver.resolve_ip('127.0.0.1') == ResolvedLocation(provider='internet', region=None, service=None, sync_token=123456789)