import json
import logging
import multiprocessing
import time

from ctypes import c_bool
from datetime import datetime, timedelta
from threading import Thread
from functools import total_ordering
from enum import Enum, IntEnum
from collections import namedtuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from dateutil import parser
from flask import make_response

import jwt


logger = logging.getLogger(__name__)


TRIAL_GRACE_PERIOD = timedelta(days=7) # 1 week
MONTHLY_GRACE_PERIOD = timedelta(days=335) # 11 months
YEARLY_GRACE_PERIOD = timedelta(days=90) # 3 months
LICENSE_SOON_DELTA = timedelta(days=7) # 1 week

LICENSE_FILENAME = 'license'

QUAY_ENTITLEMENT = 'software.quay'
QUAY_DEPLOYMENTS_ENTITLEMENT = 'software.quay.deployments'


class LicenseDecodeError(Exception):
  """ Exception raised if the license could not be read, decoded or has expired. """
  pass


def _get_date(decoded, field, default_date=datetime.min):
  """ Retrieves the encoded date found at the given field under the decoded license block.

      Returns `default_date` when the field is missing or empty. Any timezone information on the
      parsed value is discarded (not converted), so the result is always a naive datetime.
  """
  # NOTE(review): .replace(tzinfo=None) drops the UTC offset without converting it, while callers
  # compare against the naive local datetime.now(). Confirm license dates are meant to be read
  # in the deployment's local time.
  date_str = decoded.get(field)
  return parser.parse(date_str).replace(tzinfo=None) if date_str else default_date


@total_ordering
class Entitlement(object):
  """ An entitlement is a specific piece of software or functionality granted by a license. It has
      an expiration date, as well as the count of the things being granted. Entitlements are
      orderable by their counts.
  """
  def __init__(self, entitlement_name, count, product_name, expiration):
    self.name = entitlement_name
    self.count = count
    self.product_name = product_name
    self.expiration = expiration

  def __lt__(self, rhs):
    return self.count < rhs.count

  def __repr__(self):
    return str(dict(
      name=self.name,
      count=self.count,
      product_name=self.product_name,
      expiration=repr(self.expiration),
    ))

  def as_dict(self, for_private=False):
    """ Returns a dict view of the entitlement; count, product and expiration details are only
        included when `for_private` is True.
    """
    data = {
      'name': self.name,
    }

    if for_private:
      data.update({
        'count': self.count,
        'product_name': self.product_name,
        'expiration': self.expiration.as_dict(for_private=True),
      })

    return data


class ExpirationType(Enum):
  """ An enum which represents the different possible types of expirations. When an entitlement
      has expired, its expiration type tells you which level (license-wide, trial or subscription)
      imposed the most restrictive cutoff.
  """
  license_wide = 'License Wide Expiration'
  trial_only = 'Trial Only Expiration'
  in_trial = 'In-Trial Expiration'
  monthly = 'Monthly Subscription Expiration'
  yearly = 'Yearly Subscription Expiration'


@total_ordering
class Expiration(object):
  """ An Expiration is an orderable representation of an expiration date and a grace period. If
      you sort Expiration objects, they will be sorted by the actual cutoff date, which is the
      combination of the expiration date and the grace period.
  """
  def __init__(self, expiration_type, exp_date, grace_period=timedelta(seconds=0)):
    self.expiration_type = expiration_type
    self.expiration_date = exp_date
    self.grace_period = grace_period

  @property
  def expires_at(self):
    """ The effective cutoff: the expiration date plus the grace period. """
    return self.expiration_date + self.grace_period

  def is_expired(self, now):
    """ Check if the current object should already be considered expired when compared with the
        passed in datetime object.
    """
    return self.expires_at < now

  def __lt__(self, rhs):
    return self.expires_at < rhs.expires_at

  def __repr__(self):
    return str(dict(
      expiration_type=repr(self.expiration_type),
      expiration_date=repr(self.expiration_date),
      grace_period=repr(self.grace_period),
    ))

  def as_dict(self, for_private=False):
    """ Returns a dict view of the expiration; the date and grace period are only included when
        `for_private` is True.
    """
    data = {
      'expiration_type': str(self.expiration_type),
    }

    if for_private:
      data.update({
        'expiration_date': str(self.expiration_date),
        'grace_period': str(self.grace_period),
      })

    return data


class EntitlementStatus(IntEnum):
  """ An EntitlementStatus represents the current effectiveness of an Entitlement when compared
      with its corresponding requirement. As an example, if the software requires 9 items, and the
      Entitlement only provides for 7, you would use an insufficient_count status.
  """
  met = 0
  expired = 1
  insufficient_count = 2
  no_matching = 3


@total_ordering
class EntitlementValidationResult(object):
  """ An EntitlementValidationResult encodes the combination of a specific entitlement and the
      software requirement which caused it to be examined. They are orderable by the value of the
      EntitlementStatus enum, and will in general be sorted by most to least satisfiable status
      type.
  """
  def __init__(self, requirement, created_at, entitlement=None):
    self.requirement = requirement
    self._created_at = created_at
    self.entitlement = entitlement

  def get_status(self):
    """ Returns the EntitlementStatus when comparing the specified Entitlement with the
        corresponding requirement.
    """
    if self.entitlement is not None:
      if self.entitlement.expiration.is_expired(self._created_at):
        return EntitlementStatus.expired

      if self.entitlement.count < self.requirement.count:
        return EntitlementStatus.insufficient_count

      return EntitlementStatus.met

    return EntitlementStatus.no_matching

  def is_met(self):
    """ Returns whether this specific EntitlementValidationResult meets all of the criteria for
        being sufficient, including unexpired (or in the grace period), and with a sufficient count.
    """
    return self.get_status() == EntitlementStatus.met

  def __lt__(self, rhs):
    self_status = self.get_status()
    rhs_status = rhs.get_status()

    # Different statuses: sort lexically by status (most satisfiable first).
    if self_status != rhs_status:
      return self_status < rhs_status

    # Equal statuses of no_matching carry no entitlement (and so no expiration) to compare;
    # treat them as equivalent rather than dereferencing None.
    if self.entitlement is None or rhs.entitlement is None:
      return False

    # If this result has the same status as another, prefer the result with an expiration date
    # further in the future, as it will be more relevant. The results may expire, but so long as
    # this result is valid, so will the entitlement.
    return (self.entitlement.expiration.expiration_date >
            rhs.entitlement.expiration.expiration_date)

  def __repr__(self):
    return str(dict(
      requirement=repr(self.requirement),
      created_at=repr(self._created_at),
      entitlement=repr(self.entitlement),
    ))

  def description(self):
    """ Returns a human-readable summary of the requirement and its status. """
    msg = '%s requires %s: has status %s'
    return msg % (self.requirement.name, self.requirement.count, self.get_status())

  def as_dict(self, for_private=False):
    """ Returns a dict view of the requirement, its status and (if any) the matched entitlement. """
    def req_view():
      return {
        'name': self.requirement.name,
        'count': self.requirement.count,
      }

    data = {
      'requirement': req_view(),
      'status': str(self.get_status()),
    }

    if self.entitlement is not None:
      data['entitlement'] = self.entitlement.as_dict(for_private=for_private)

    return data


class License(object):
  """ License represents a fully decoded and validated (but potentially expired) license. """
  def __init__(self, decoded):
    self.decoded = decoded

  def validate_entitlement_requirement(self, entitlement_req, check_time):
    """ Returns the most satisfiable EntitlementValidationResult for the given requirement at
        `check_time`, or a no_matching result if the license has no such entitlement.
    """
    all_active_entitlements = list(self._find_entitlements(entitlement_req.name))
    if not all_active_entitlements:
      return EntitlementValidationResult(entitlement_req, check_time)

    # min() returns the first smallest element, which matches sorting and taking index 0.
    return min(EntitlementValidationResult(entitlement_req, check_time, ent)
               for ent in all_active_entitlements)

  def _find_entitlements(self, entitlement_name):
    """ Yields one Entitlement per subscription granting `entitlement_name`, with its expiration
        clamped to the license-wide expiration.
    """
    license_expiration = Expiration(
      ExpirationType.license_wide,
      _get_date(self.decoded, 'expirationDate'),
    )

    for sub in self.decoded.get('subscriptions', {}).values():
      entitlement_count = sub.get('entitlements', {}).get(entitlement_name)

      if entitlement_count is not None:
        # The effective expiration is whichever of the subscription or the license ends first.
        entitlement_expiration = min(self._sub_expiration(sub), license_expiration)
        yield Entitlement(
          entitlement_name,
          entitlement_count,
          sub.get('productName', 'unknown'),
          entitlement_expiration,
        )

  @staticmethod
  def _sub_expiration(subscription):
    """ Builds the Expiration (date plus grace period) for a single subscription block. """
    # A trial license has its own end logic, and uses the trialEnd property
    if subscription.get('trialOnly', False):
      trial_expiration = Expiration(
        ExpirationType.trial_only,
        _get_date(subscription, 'trialEnd'),
        TRIAL_GRACE_PERIOD,
      )
      return trial_expiration

    # From here we always use the serviceEnd
    service_end = _get_date(subscription, 'serviceEnd')

    if subscription.get('inTrial', False):
      return Expiration(ExpirationType.in_trial, service_end, TRIAL_GRACE_PERIOD)

    if subscription.get('durationPeriod') == 'yearly':
      return Expiration(ExpirationType.yearly, service_end, YEARLY_GRACE_PERIOD)

    # We assume monthly license unless specified otherwise
    return Expiration(ExpirationType.monthly, service_end, MONTHLY_GRACE_PERIOD)

  def validate(self, config):
    """ Returns a list of EntitlementValidationResult objects, one per requirement. """
    requirements = _gen_entitlement_requirements(config)
    now = datetime.now()
    return [self.validate_entitlement_requirement(req, now) for req in requirements]


_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""

# load_pem_public_key expects bytes; .encode('ascii') is a no-op-equivalent on Python 2 and
# required on Python 3.
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA.encode('ascii'),
                                               backend=default_backend())


def decode_license(license_contents, public_key_instance=None):
  """ Decodes the specified license contents, returning the decoded license.

      Raises LicenseDecodeError if the token fails JWT validation or its `license` claim is not a
      JSON object.
  """
  license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY
  try:
    # NOTE(review): no `algorithms` allow-list is passed to jwt.decode; consider pinning the
    # expected signing algorithm once it is confirmed.
    jwt_data = jwt.decode(license_contents, key=license_public_key)
  except jwt.exceptions.InvalidTokenError as ite:
    # InvalidTokenError is the base of DecodeError and also covers e.g. expired-signature errors,
    # which previously escaped as non-LicenseDecodeError exceptions.
    logger.exception('Could not decode license file')
    raise LicenseDecodeError('Could not decode license found: %s' % ite)

  try:
    decoded = json.loads(jwt_data.get('license', '{}'))
  except (TypeError, ValueError) as ve:
    # TypeError: the claim is not a string; ValueError: the string is not valid JSON.
    logger.exception('Could not decode license file')
    raise LicenseDecodeError('Could not decode license found: %s' % ve)

  if not isinstance(decoded, dict):
    # License relies on dict access (.get) on the decoded block.
    raise LicenseDecodeError('Could not decode license found: license block is not an object')

  return License(decoded)


LICENSE_VALIDATION_INTERVAL = 3600 # seconds
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60 # seconds


# The typename must match the module attribute name so instances can be pickled.
EntitlementRequirement = namedtuple('EntitlementRequirement', ['name', 'count'])


def _gen_entitlement_requirements(config_obj):
  """ Returns the EntitlementRequirements implied by the given config: one Quay entitlement and
      one deployment per configured storage region.
  """
  # `or []` also covers the key being present with a null value.
  config_regions = len(config_obj.get('DISTRIBUTED_STORAGE_CONFIG') or [])
  return [
    EntitlementRequirement(QUAY_ENTITLEMENT, 1),
    EntitlementRequirement(QUAY_DEPLOYMENTS_ENTITLEMENT, config_regions),
  ]


class LicenseValidator(Thread):
  """
  LicenseValidator is a thread that asynchronously reloads and validates license files.

  This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
  synchronization primitive.
  """
  def __init__(self, config_provider, *args, **kwargs):
    config = config_provider.get_config() or {}

    self._config_provider = config_provider
    self._entitlement_requirements = _gen_entitlement_requirements(config)

    # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
    self._license_is_insufficient = multiprocessing.Value(c_bool, True)
    self._license_expiring_soon = multiprocessing.Value(c_bool, True)

    super(LicenseValidator, self).__init__(*args, **kwargs)
    self.daemon = True

  @property
  def expiring_soon(self):
    """ Returns whether the license will be expiring soon (a week from now). """
    return self._license_expiring_soon.value

  @property
  def insufficient(self):
    """ Returns whether the last validation found any requirement unmet. """
    return self._license_is_insufficient.value

  def compute_license_sufficiency(self):
    """ Check whether all of our requirements are met, and set the status of the result of the
        check, which will be used to disable the software.
        Returns True if any requirements are not met, and False if all are met.
    """
    try:
      current_license = self._config_provider.get_license()
      now = datetime.now()
      soon = now + LICENSE_SOON_DELTA
      any_invalid = not all(current_license.validate_entitlement_requirement(req, now).is_met()
                            for req in self._entitlement_requirements)
      soon_invalid = not all(current_license.validate_entitlement_requirement(req, soon).is_met()
                             for req in self._entitlement_requirements)
      logger.debug('updating license license_is_insufficient to %s', any_invalid)
      logger.debug('updating license license_expiring_soon to %s', soon_invalid)
    except (IOError, LicenseDecodeError):
      logger.exception('failed to validate license')
      any_invalid = True
      soon_invalid = False

    self._license_is_insufficient.value = any_invalid
    self._license_expiring_soon.value = soon_invalid
    return any_invalid

  def run(self):
    """ Re-validates the license forever, retrying sooner while it is invalid. """
    logger.debug('Starting license validation thread')
    while True:
      try:
        invalid = self.compute_license_sufficiency()
      except Exception:
        # compute_license_sufficiency only handles IOError/LicenseDecodeError; anything else would
        # otherwise kill this daemon thread and freeze the shared flags at their last value.
        # Fail closed and retry on the short interval.
        logger.exception('unexpected error while validating license')
        self._license_is_insufficient.value = True
        invalid = True

      sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if invalid else LICENSE_VALIDATION_INTERVAL
      logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
      time.sleep(sleep_time)

  def enforce_license_before_request(self, blueprint, response_func=None):
    """
    Adds a pre-check to a Flask blueprint such that if this validator determines the license has
    become invalid, the client will receive a HTTP 402 response (or whatever `response_func`
    returns).
    """
    if response_func is None:
      def _response_func():
        return make_response('License is insufficient.', 402)
      response_func = _response_func

    def _enforce_license():
      if self.insufficient:
        logger.debug('blocked interaction due to insufficient license')
        return response_func()

    blueprint.before_request(_enforce_license)