import json
import logging
import multiprocessing
import time

from ctypes import c_bool
from datetime import datetime, timedelta
from threading import Thread

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from dateutil import parser
from flask import abort

import jwt


logger = logging.getLogger(__name__)

TRIAL_GRACE_PERIOD = timedelta(7, 0)  # 1 week
MONTHLY_GRACE_PERIOD = timedelta(30, 0)  # 1 month
YEARLY_GRACE_PERIOD = timedelta(90, 0)  # 3 months

LICENSE_PRODUCT_NAME = "quay-enterprise"
LICENSE_FILENAME = 'license'


class LicenseError(Exception):
    """ Exception raised if the license could not be read, decoded or has expired. """
    pass


class LicenseDecodeError(LicenseError):
    """ Exception raised if the license could not be decoded. """
    pass


class LicenseValidationError(LicenseError):
    """ Exception raised if the license could not be validated. """
    pass


def _get_date(decoded, field):
    """ Retrieves the encoded date found at the given field under the decoded license block.

    Returns a naive datetime. If the field is missing or empty, returns a datetime two days
    before now.
    """
    date_str = decoded.get(field)
    if date_str:
        # NOTE(review): any UTC offset in `date_str` is dropped rather than converted, and callers
        # compare the result against naive local `datetime.now()` -- confirm license dates are
        # meant to be read in server-local time.
        return parser.parse(date_str).replace(tzinfo=None)

    # NOTE(review): when a grace period larger than two days is subtracted from the comparison
    # date (trialEnd / serviceEnd checks), this default never counts as expired -- confirm that
    # a missing date is meant to fall back to the overall `expirationDate` check.
    return datetime.now() - timedelta(days=2)


class License(object):
    """ License represents a fully decoded and validated (but potentially expired) license. """

    def __init__(self, decoded):
        # The decoded license payload, as produced by `decode_license`.
        self.decoded = decoded

    @property
    def subscription(self):
        """ Returns the Quay Enterprise subscription, if any. """
        for sub in self.decoded.get('subscriptions', {}).values():
            if sub.get('productName') == LICENSE_PRODUCT_NAME:
                return sub

        return None

    @property
    def is_expired(self):
        """ Returns whether the license is expired as of now, grace periods included. """
        return self._get_expired(datetime.now())

    def validate(self, config):
        """ Validates the license and all its entitlements against the given config.

        Raises LicenseValidationError if the license has expired, or if the config declares more
        entries in DISTRIBUTED_STORAGE_CONFIG than the license's region entitlement allows.
        """
        # Check that the license has not expired.
        if self.is_expired:
            raise LicenseValidationError('License has expired')

        # Check the maximum number of replication regions. An entitlement of -1 means unlimited;
        # any other value is honored, with a floor of one region. (Previously `min(..., 1)` was
        # used, which silently capped every finite entitlement at a single region.)
        max_regions = self.decoded.get('entitlements', {}).get('software.quay.regions', 1)
        if max_regions != -1:
            max_regions = max(max_regions, 1)

        config_regions = len(config.get('DISTRIBUTED_STORAGE_CONFIG', []))
        if max_regions != -1 and config_regions > max_regions:
            msg = '{} regions configured, but license file allows up to {}'.format(config_regions,
                                                                                   max_regions)
            raise LicenseValidationError(msg)

    def _get_expired(self, compare_date):
        """ Returns whether the license is expired as of `compare_date`.

        The license is expired if its overall expiration date has passed, if it has no Quay
        Enterprise subscription, or if the subscription's end date (plus the grace period for its
        trial/duration type) has passed. Unknown duration periods are treated as expired.
        """
        # Check if the license overall has expired.
        expiration_date = _get_date(self.decoded, 'expirationDate')
        if expiration_date <= compare_date:
            logger.debug('License expired on %s', expiration_date)
            return True

        # Check for any QE subscriptions.
        sub = self.subscription
        if sub is None:
            return True

        # Check for a trial-only license.
        if sub.get('trialOnly', False):
            trial_end_date = _get_date(sub, 'trialEnd')
            logger.debug('Trial-only license expires on %s', trial_end_date)
            return trial_end_date <= (compare_date - TRIAL_GRACE_PERIOD)

        # Check for a normal license that is in trial.
        service_end_date = _get_date(sub, 'serviceEnd')
        if sub.get('inTrial', False):
            # If the subscription is in a trial, but not a trial only
            # subscription, give 7 days after trial end to update license
            # to one which has been paid (they've put in a credit card and it
            # might auto convert, so we could assume it will auto-renew)
            logger.debug('In-trial license expires on %s', service_end_date)
            return service_end_date <= (compare_date - TRIAL_GRACE_PERIOD)

        # Otherwise, check the service expiration.
        duration_period = sub.get('durationPeriod', 'months')

        # If the subscription is monthly, give a 1 month (30 day) grace period.
        if duration_period == "months":
            logger.debug('Monthly license expires on %s', service_end_date)
            return service_end_date <= (compare_date - MONTHLY_GRACE_PERIOD)

        # If the subscription is yearly, give a 3 month (90 day) grace period.
        if duration_period == "years":
            logger.debug('Yearly license expires on %s', service_end_date)
            return service_end_date <= (compare_date - YEARLY_GRACE_PERIOD)

        # Unrecognized duration period: fail closed.
        return True


_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""

# `load_pem_public_key` requires bytes on Python 3; encoding is a no-op for ASCII on Python 2.
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA.encode('ascii'),
                                               backend=default_backend())


def decode_license(license_contents, public_key_instance=None):
    """ Decodes the specified license contents, returning the decoded license.

    `license_contents` is a signed JWT whose 'license' claim holds a JSON object. The signature is
    verified against `public_key_instance`, or the production public key when none is given.

    Raises LicenseDecodeError if the token cannot be decoded/verified, or if its 'license' claim
    is not a JSON object.
    """
    license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY
    try:
        # NOTE(review): no `algorithms=` allow-list is passed, so the algorithm named in the
        # token header is trusted; consider pinning it (e.g. RS256) once confirmed.
        jwt_data = jwt.decode(license_contents, key=license_public_key)
    except jwt.exceptions.InvalidTokenError as ite:
        # InvalidTokenError is the base of DecodeError and of the claim-validation errors
        # (e.g. an expired `exp`), all of which mean the license is unusable.
        logger.exception('Could not decode license file')
        raise LicenseDecodeError('Could not decode license found: %s' % ite)

    try:
        decoded = json.loads(jwt_data.get('license', '{}'))
    except (TypeError, ValueError) as ve:
        # TypeError covers a 'license' claim that is not a string at all.
        logger.exception('Could not decode license file')
        raise LicenseDecodeError('Could not decode license found: %s' % ve)

    if not isinstance(decoded, dict):
        # License's accessors call `.get` on the payload, so anything but an object is invalid.
        logger.error('Could not decode license file: payload is not a JSON object')
        raise LicenseDecodeError('Could not decode license found: payload is not a JSON object')

    return License(decoded)


LICENSE_VALIDATION_INTERVAL = 3600  # seconds
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60  # seconds


class LicenseValidator(Thread):
    """
    LicenseValidator is a thread that asynchronously reloads and validates license files.

    This thread is meant to be run before registry gunicorn workers fork and uses shared memory as
    a synchronization primitive.
    """

    def __init__(self, license_path):
        self._license_path = license_path

        # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
        self._license_is_expired = multiprocessing.Value(c_bool, True)

        super(LicenseValidator, self).__init__()

    @property
    def expired(self):
        """ Returns the most recently published expiration state (True until first checked). """
        return self._license_is_expired.value

    def _check_expiration(self):
        """ Reloads the license file, publishes whether it is expired and returns that value.

        A license that cannot be read or decoded is treated as expired.
        """
        try:
            with open(self._license_path) as license_file:
                current_license = decode_license(license_file.read())
        except (IOError, LicenseError):
            logger.exception('failed to validate license')
            self._license_is_expired.value = True
            return True

        # Evaluate once: `is_expired` reads the clock, so separate calls for logging, publishing
        # and returning could disagree right at the expiry boundary.
        expired = current_license.is_expired
        logger.debug('updating license expiration to %s', expired)
        self._license_is_expired.value = expired
        return expired

    def run(self):
        """ Revalidates the license forever, more frequently while it is considered expired. """
        logger.debug('Starting license validation thread')
        while True:
            try:
                expired = self._check_expiration()
            except Exception:
                # Top-level boundary of this thread: an unexpected error (e.g. a ValueError from
                # an unparseable date in the license) must not kill the loop, or the shared flag
                # would never be refreshed again. Fail closed and retry later.
                logger.exception('unexpected error while validating license')
                self._license_is_expired.value = True
                expired = True

            sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if expired else LICENSE_VALIDATION_INTERVAL
            logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
            time.sleep(sleep_time)


def enforce_license_before_request(license_validator, blueprint):
    """
    Adds a pre-check to a Flask blueprint such that if the provided license_validator determines the
    license has become invalid, the client will receive a HTTP 402 response.
    """
    def _enforce_license():
        if license_validator.expired:
            abort(402)

    blueprint.before_request(_enforce_license)