diff --git a/app.py b/app.py index dc9671001..579dd0229 100644 --- a/app.py +++ b/app.py @@ -35,6 +35,7 @@ from util.security.signing import Signer from util.security.instancekeys import InstanceKeys from util.saas.cloudwatch import start_cloudwatch_sender from util.config.provider import get_config_provider +from util.config.provider.baseprovider import SetupIncompleteException from util.config.configutil import generate_secret_key from util.config.superusermanager import SuperUserManager from util.secscan.api import SecurityScannerAPI diff --git a/test/test_license.py b/test/test_license.py new file mode 100644 index 000000000..0b041924b --- /dev/null +++ b/test/test_license.py @@ -0,0 +1,125 @@ +import unittest +import jwt + +from datetime import datetime, timedelta +from util.config.provider.license import (decode_license, LICENSE_PRODUCT_NAME, + LicenseValidationError) + +from Crypto.PublicKey import RSA + + +class TestLicense(unittest.TestCase): + def keys(self): + with open('test/data/test.pem') as f: + private_key = f.read() + + return (RSA.importKey(private_key).publickey().exportKey('PEM'), private_key) + + def create_license(self, license_data): + (public_key, private_key) = self.keys() + + # Encode the license with the JWT key. + encoded = jwt.encode(license_data, private_key, 'RS256') + + # Decode it into a license object. + return decode_license(encoded, public_key_contents=public_key) + + def get_license(self, expiration_delta=None, **kwargs): + license_data = { + 'expirationDate': str(datetime.now() + expiration_delta), + } + + if kwargs: + sub = { + 'productName': LICENSE_PRODUCT_NAME, + } + + sub['trialOnly'] = kwargs.get('trial_only', False) + sub['inTrial'] = kwargs.get('in_trial', False) + sub['entitlements'] = kwargs.get('entitlements', []) + + if 'trial_end' in kwargs: + sub['trialEnd'] = str(datetime.now() + kwargs['trial_end']) + + if 'service_end' in kwargs: + sub['serviceEnd'] = str(datetime.now() + kwargs['service_end']) + + if 'duration' in kwargs: + sub['durationPeriod'] = kwargs['duration'] + + license_data['subscriptions'] = [sub] + + decoded_license = self.create_license(license_data) + return decoded_license + + def test_license_itself_expired(self): + # License is expired. + license = self.get_license(timedelta(days=-30)) + + def test_no_qe_subscription(self): + # License is not expired, but there is no QE sub, so not valid. + license = self.get_license(timedelta(days=30)) + + def test_trial_withingrace(self): + license = self.get_license(timedelta(days=30), trial_only=True, trial_end=timedelta(days=-1)) + self.assertFalse(license.is_expired) + + def test_trial_outsidegrace(self): + license = self.get_license(timedelta(days=30), trial_only=True, trial_end=timedelta(days=-10)) + self.assertTrue(license.is_expired) + + def test_trial_intrial_withingrace(self): + license = self.get_license(timedelta(days=30), in_trial=True, service_end=timedelta(days=-1)) + self.assertFalse(license.is_expired) + + def test_trial_intrial_outsidegrace(self): + license = self.get_license(timedelta(days=30), in_trial=True, service_end=timedelta(days=-10)) + self.assertTrue(license.is_expired) + + def test_monthly_license_valid(self): + license = self.get_license(timedelta(days=30), service_end=timedelta(days=10), duration='monthly') + self.assertFalse(license.is_expired) + + def test_monthly_license_withingrace(self): + license = self.get_license(timedelta(days=30), service_end=timedelta(days=-10), duration='monthly') + self.assertFalse(license.is_expired) + + def test_monthly_license_outsidegrace(self): + license = self.get_license(timedelta(days=30), service_end=timedelta(days=-40), duration='monthly') + self.assertTrue(license.is_expired) + + def test_yearly_license_withingrace(self): + license = self.get_license(timedelta(days=30), service_end=timedelta(days=-40), duration='years') + self.assertFalse(license.is_expired) + + def test_yearly_license_outsidegrace(self): + license = self.get_license(timedelta(days=30), service_end=timedelta(days=-100), duration='years') + self.assertTrue(license.is_expired) + + def test_valid_license(self): + license = self.get_license(timedelta(days=300), service_end=timedelta(days=40), duration='years') + self.assertFalse(license.is_expired) + + def test_validate_basic_license(self): + decoded = self.get_license(timedelta(days=30), entitlements={}) + decoded.validate({'DISTRIBUTED_STORAGE_CONFIG': [{}]}) + + def test_validate_storage_entitlement_valid(self): + decoded = self.get_license(timedelta(days=30), entitlements={ + 'software.quay.regions': 2, + }) + + decoded.validate({'DISTRIBUTED_STORAGE_CONFIG': [{}]}) + + def test_validate_storage_entitlement_invalid(self): + decoded = self.get_license(timedelta(days=30), entitlements={ + 'software.quay.regions': 1, + }) + + with self.assertRaises(LicenseValidationError): + decoded.validate({'DISTRIBUTED_STORAGE_CONFIG': [{}, {}]}) + + +if __name__ == '__main__': + unittest.main() + diff --git a/util/config/provider/baseprovider.py b/util/config/provider/baseprovider.py index 1cf9c654f..18dd13add 100644 --- a/util/config/provider/baseprovider.py +++ b/util/config/provider/baseprovider.py @@ -1,12 +1,18 @@ import yaml import logging +from util.config.provider.license import LICENSE_FILENAME, LicenseError, decode_license + logger = logging.getLogger(__name__) class CannotWriteConfigException(Exception): """ Exception raised when the config cannot be written. """ pass +class SetupIncompleteException(Exception): + """ Exception raised when attempting to verify config that has not yet been setup. """ + pass + def import_yaml(config_obj, config_file): with open(config_file) as f: c = yaml.safe_load(f) @@ -38,6 +44,8 @@ def export_yaml(config_obj, config_file): class BaseProvider(object): """ A configuration provider helps to load, save, and handle config override in the application. """ + def __init__(self): + self.license = None @property def provider_id(self): @@ -81,4 +89,26 @@ class BaseProvider(object): """ If true, the configuration loaded into memory for the app does not match that on disk, indicating that this container requires a restart. """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + def validate_license(self, config): + """ Validates that the configuration matches the license file (if any). """ + if not config.get('SETUP_COMPLETE', False): + raise SetupIncompleteException() + + with self._get_license_file() as f: + license_file_contents = f.read() + + self.license = decode_license(license_file_contents) + self.license.validate(config) + + def _get_license_file(self): + """ Returns the contents of the license file. """ + try: + return self.get_volume_file(LICENSE_FILENAME) + except IOError: + msg = 'Could not open license file. Please make sure it is in your config volume.' + raise LicenseError(msg) + + + diff --git a/util/config/provider/license.py b/util/config/provider/license.py new file mode 100644 index 000000000..829a944cc --- /dev/null +++ b/util/config/provider/license.py @@ -0,0 +1,134 @@ +import logging + +from dateutil import parser +from datetime import datetime, timedelta +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_pem_public_key + +import jwt + +logger = logging.getLogger(__name__) + +TRIAL_GRACE_PERIOD = timedelta(7, 0) # 1 week +MONTHLY_GRACE_PERIOD = timedelta(30, 0) # 1 month +YEARLY_GRACE_PERIOD = timedelta(90, 0) # 3 months + +LICENSE_PRODUCT_NAME = "quay-enterprise" + +class LicenseError(Exception): + """ Exception raised if the license could not be read, decoded or has expired. """ + pass + + +class LicenseDecodeError(LicenseError): + """ Exception raised if the license could not be decoded. """ + pass + + +class LicenseValidationError(LicenseError): + """ Exception raised if the license could not be validated. """ + pass + + +def _get_date(decoded, field): + """ Retrieves the encoded date found at the given field under the decoded license block. """ + date_str = decoded.get(field) + if date_str: + return parser.parse(date_str).replace(tzinfo=None) + + return datetime.now() - timedelta(days=2) + + +class License(object): + """ License represents a fully decoded and validated (but potentially expired) license. """ + def __init__(self, decoded): + self.decoded = decoded + + def validate(self, config): + """ Validates the license and all its entitlements against the given config. """ + # Check that the license has not expired. + if self.is_expired: + raise LicenseValidationError('License has expired') + + # Check the maximum number of replication regions. + max_regions = min(self.decoded.get('entitlements', {}).get('software.quay.regions', 1), 1) + config_regions = len(config.get('DISTRIBUTED_STORAGE_CONFIG', [])) + if max_regions != -1 and config_regions > max_regions: + msg = '{} regions configured, but license file allows up to {}'.format(config_regions, + max_regions) + raise LicenseValidationError(msg) + + @property + def is_expired(self): + return self._get_expired(datetime.now()) + + def _get_expired(self, compare_date): + # Check if the license overall has expired. + expiration_date = _get_date(self.decoded, 'expirationDate') + if expiration_date <= compare_date: + logger.debug('License expired on %s', expiration_date) + return True + + # Check for any QE subscriptions. + for sub in self.decoded.get('subscriptions', []): + if sub.get('productName') != LICENSE_PRODUCT_NAME: + continue + + # Check for a trial-only license. + if sub.get('trialOnly', False): + trial_end_date = _get_date(sub, 'trialEnd') + logger.debug('Trial-only license expires on %s', trial_end_date) + return trial_end_date <= (compare_date - TRIAL_GRACE_PERIOD) + + # Check for a normal license that is in trial. + service_end_date = _get_date(sub, 'serviceEnd') + if sub.get('inTrial', False): + # If the subscription is in a trial, but not a trial only + # subscription, give 7 days after trial end to update license + # to one which has been paid (they've put in a credit card and it + # might auto convert, so we could assume it will auto-renew) + logger.debug('In-trial license expires on %s', service_end_date) + return service_end_date <= (compare_date - TRIAL_GRACE_PERIOD) + + # Otherwise, check the service expiration. + duration_period = sub.get('durationPeriod', 'monthly') + + # If the subscription is monthly, give 3 months grace period + if duration_period == "monthly": + logger.debug('Monthly license expires on %s', service_end_date) + return service_end_date <= (compare_date - MONTHLY_GRACE_PERIOD) + + if duration_period == "years": + logger.debug('Yearly license expires on %s', service_end_date) + return service_end_date <= (compare_date - YEARLY_GRACE_PERIOD) + + return True + + +LICENSE_FILENAME = 'license' + +_PROD_LICENSE_PUBLIC_KEY_DATA = """ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR +e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA +1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv +busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt +EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4 +ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN +qwIDAQAB +-----END PUBLIC KEY----- +""" + +def decode_license(license_contents, public_key_contents=None): + """ Decodes the specified license contents, returning the decoded license. """ + public_key_data = public_key_contents or _PROD_LICENSE_PUBLIC_KEY_DATA + license_public_key = load_pem_public_key(public_key_data, backend=default_backend()) + try: + decoded = jwt.decode(license_contents, key=license_public_key) + except jwt.exceptions.DecodeError as de: + logger.exception('Could not decode license file') + raise LicenseDecodeError('Could not decode license found: %s' % de.message) + + return License(decoded) + +