With this change, if all entitlements are valid, we sort to show the entitlement that will expire the farthest in the future, as that defines the point at which the user must act before the license becomes invalid.
		
			
				
	
	
		
			412 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			412 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import json
 | |
| import logging
 | |
| import multiprocessing
 | |
| import time
 | |
| 
 | |
| from ctypes import c_bool
 | |
| from datetime import datetime, timedelta
 | |
| from threading import Thread
 | |
| from functools import total_ordering
 | |
| from enum import Enum, IntEnum
 | |
| from collections import namedtuple
 | |
| 
 | |
| from cryptography.hazmat.backends import default_backend
 | |
| from cryptography.hazmat.primitives.serialization import load_pem_public_key
 | |
| from dateutil import parser
 | |
| from flask import make_response
 | |
| 
 | |
| import jwt
 | |
| 
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
# Grace periods: extra time added to a subscription's end date before the
# corresponding entitlement is considered expired (see Expiration.expires_at).
TRIAL_GRACE_PERIOD = timedelta(days=7)    # 1 week
MONTHLY_GRACE_PERIOD = timedelta(days=335) # 11 months
YEARLY_GRACE_PERIOD = timedelta(days=90)  # 3 months

# Look-ahead window used by LicenseValidator to decide whether the license is
# "expiring soon".
LICENSE_SOON_DELTA = timedelta(days=7) # 1 week

# NOTE(review): not referenced in this module; presumably the name of the
# license file read by the config provider -- confirm against callers.
LICENSE_FILENAME = 'license'

# Entitlement names looked up in each subscription's 'entitlements' mapping.
QUAY_ENTITLEMENT = 'software.quay'
QUAY_DEPLOYMENTS_ENTITLEMENT = 'software.quay.deployments'
 | |
| 
 | |
| 
 | |
| 
 | |
class LicenseDecodeError(Exception):
  """ Raised when a license cannot be read or decoded, or when it has expired. """
 | |
| 
 | |
| 
 | |
| def _get_date(decoded, field, default_date=datetime.min):
 | |
|   """ Retrieves the encoded date found at the given field under the decoded license block. """
 | |
|   date_str = decoded.get(field)
 | |
|   return parser.parse(date_str).replace(tzinfo=None) if date_str else default_date
 | |
| 
 | |
| 
 | |
@total_ordering
class Entitlement(object):
  """ An entitlement is a specific piece of software or functionality granted
      by a license. It has an expiration date, as well as the count of the
      things being granted.

      Entitlements are orderable by their counts. Equality follows the same
      key: two entitlements with the same count compare as equal, regardless
      of name or product, so that every comparison operator agrees.
  """
  def __init__(self, entitlement_name, count, product_name, expiration):
    self.name = entitlement_name
    self.count = count
    self.product_name = product_name
    self.expiration = expiration

  def __eq__(self, rhs):
    # total_ordering builds <=, > and >= out of __lt__ *and* __eq__. Without
    # an explicit __eq__ it falls back to identity, which makes `a <= b` False
    # and `a > b` True for two distinct entitlements with equal counts.
    if not isinstance(rhs, Entitlement):
      return NotImplemented
    return self.count == rhs.count

  def __ne__(self, rhs):
    # Python 2 does not derive != from ==, so spell it out.
    equal = self.__eq__(rhs)
    return equal if equal is NotImplemented else not equal

  def __hash__(self):
    # Keep instances hashable and consistent with __eq__.
    return hash(self.count)

  def __lt__(self, rhs):
    return self.count < rhs.count

  def __repr__(self):
    return str(dict(
      name=self.name,
      count=self.count,
      product_name=self.product_name,
      expiration=repr(self.expiration),
    ))

  def as_dict(self, for_private=False):
    """ Returns a dict view of this entitlement. Only the name is included
        unless `for_private` is set, in which case the count, product name and
        the (private) expiration view are added as well.
    """
    data = {
      'name': self.name,
    }

    if for_private:
      data.update({
        'count': self.count,
        'product_name': self.product_name,
        'expiration': self.expiration.as_dict(for_private=True),
      })

    return data
 | |
| 
 | |
class ExpirationType(Enum):
  """ An enum which represents the different possible types of expirations. If
      you possess an expired enum, you can use this to figure out at what level
      the expiration was most restrictive.
  """
  # Expiration taken from the license's own 'expirationDate'.
  license_wide = 'License Wide Expiration'
  # Subscription flagged 'trialOnly'; based on its 'trialEnd' date.
  trial_only = 'Trial Only Expiration'
  # Subscription flagged 'inTrial'; based on its 'serviceEnd' date.
  in_trial = 'In-Trial Expiration'
  # Default for subscriptions that are not trial or yearly.
  monthly = 'Monthly Subscription Expiration'
  # Subscription whose 'durationPeriod' is 'yearly'.
  yearly = 'Yearly Subscription Expiration'
 | |
| 
 | |
| 
 | |
@total_ordering
class Expiration(object):
  """ An Expiration is an orderable representation of an expiration date and a
      grace period. If you sort Expiration objects, they will be sorted by the
      actual cutoff date, which is the combination of the expiration date and
      the grace period. Two expirations with the same cutoff compare as equal,
      whatever their type, so that all comparison operators agree.
  """
  def __init__(self, expiration_type, exp_date, grace_period=timedelta(seconds=0)):
    self.expiration_type = expiration_type
    self.expiration_date = exp_date
    self.grace_period = grace_period

  @property
  def expires_at(self):
    """ The effective cutoff: the expiration date plus the grace period. """
    return self.expiration_date + self.grace_period

  def is_expired(self, now):
    """ Check if the current object should already be considered expired when
        compared with the passed in datetime object. The cutoff instant itself
        is still considered valid.
    """
    return self.expires_at < now

  def __eq__(self, rhs):
    # total_ordering derives <=, > and >= from __lt__ together with __eq__.
    # Relying on the default identity comparison makes `a > b` True for two
    # distinct expirations that share the same cutoff.
    if not isinstance(rhs, Expiration):
      return NotImplemented
    return self.expires_at == rhs.expires_at

  def __ne__(self, rhs):
    # Python 2 does not derive != from ==, so spell it out.
    equal = self.__eq__(rhs)
    return equal if equal is NotImplemented else not equal

  def __hash__(self):
    # Keep instances hashable and consistent with __eq__.
    return hash(self.expires_at)

  def __lt__(self, rhs):
    return self.expires_at < rhs.expires_at

  def __repr__(self):
    return str(dict(
      expiration_type=repr(self.expiration_type),
      expiration_date=repr(self.expiration_date),
      grace_period=repr(self.grace_period),
    ))

  def as_dict(self, for_private=False):
    """ Returns a dict view of this expiration. Only the type is included
        unless `for_private` is set, in which case the date and grace period
        are added as strings.
    """
    data = {
      'expiration_type': str(self.expiration_type),
    }

    if for_private:
      data.update({
        'expiration_date': str(self.expiration_date),
        'grace_period': str(self.grace_period),
      })

    return data
 | |
| 
 | |
| 
 | |
class EntitlementStatus(IntEnum):
  """ An EntitlementStatus represent the current effectiveness of an
      Entitlement when compared with its corresponding requirement. As an
      example, if the software requires 9 items, and the Entitlement only
      provides for 7, you would use an insufficient_count status.
  """
  # Values are ordered: validation results sort by this integer, lowest first.
  # The entitlement is unexpired and grants at least the required count.
  met = 0
  # The entitlement's cutoff (expiration plus grace period) has passed.
  expired = 1
  # The entitlement is unexpired but grants fewer than the required count.
  insufficient_count = 2
  # No entitlement was found for the requirement.
  no_matching = 3
 | |
| 
 | |
| 
 | |
@total_ordering
class EntitlementValidationResult(object):
  """ An EntitlementValidationResult encodes the combination of a specific
      entitlement and the software requirement which caused it to be examined.
      They are orderable by the value of the EntitlementStatus enum, and will
      in general be sorted by most to least satisfiable status type. Results
      with the same status are ordered so that the one whose entitlement has
      the later expiration date comes first.
  """
  def __init__(self, requirement, created_at, entitlement=None):
    self.requirement = requirement
    self._created_at = created_at
    self.entitlement = entitlement

  def get_status(self):
    """ Returns the EntitlementStatus when comparing the specified Entitlement
        with the corresponding requirement.
    """
    if self.entitlement is not None:
      if self.entitlement.expiration.is_expired(self._created_at):
        return EntitlementStatus.expired

      if self.entitlement.count < self.requirement.count:
        return EntitlementStatus.insufficient_count

      return EntitlementStatus.met

    return EntitlementStatus.no_matching

  def is_met(self):
    """ Returns whether this specific EntitlementValidationResult meets all
        of the criteria for being sufficient, including unexpired (or in the
        grace period), and with a sufficient count.
    """
    return self.get_status() == EntitlementStatus.met

  def _expiration_date(self):
    """ Returns the entitlement's expiration date, or None without an entitlement. """
    if self.entitlement is None:
      return None
    return self.entitlement.expiration.expiration_date

  def __eq__(self, rhs):
    # total_ordering derives <=, > and >= from __lt__ together with __eq__;
    # equality therefore has to use the same key as the ordering.
    if not isinstance(rhs, EntitlementValidationResult):
      return NotImplemented
    return (self.get_status() == rhs.get_status() and
            self._expiration_date() == rhs._expiration_date())

  def __ne__(self, rhs):
    # Python 2 does not derive != from ==, so spell it out.
    equal = self.__eq__(rhs)
    return equal if equal is NotImplemented else not equal

  def __hash__(self):
    # Keep instances hashable and consistent with __eq__.
    return hash((self.get_status(), self._expiration_date()))

  def __lt__(self, rhs):
    my_status = self.get_status()
    their_status = rhs.get_status()

    # Different statuses: sort by the status value.
    if my_status != their_status:
      return my_status < their_status

    # Two results without an entitlement (both no_matching) have no
    # expiration to compare; treat them as equivalent instead of raising
    # AttributeError on the missing entitlement.
    if self.entitlement is None or rhs.entitlement is None:
      return False

    # If this result has the same status as another, return the result with an expiration date
    # further in the future, as it will be more relevant. The results may expire, but so long as
    # this result is valid, so will the entitlement.
    return self._expiration_date() > rhs._expiration_date()

  def __repr__(self):
    return str(dict(
      requirement=repr(self.requirement),
      created_at=repr(self._created_at),
      entitlement=repr(self.entitlement),
    ))

  def as_dict(self, for_private=False):
    """ Returns a dict with the requirement (name and count), the status as a
        string and, when an entitlement is attached, its dict view (private
        fields included only if `for_private` is set).
    """
    def req_view():
      return {
        'name': self.requirement.name,
        'count': self.requirement.count,
      }

    data = {
      'requirement': req_view(),
      'status': str(self.get_status()),
    }

    if self.entitlement is not None:
      data['entitlement'] = self.entitlement.as_dict(for_private=for_private)

    return data
 | |
| 
 | |
| 
 | |
class License(object):
  """ Wraps the decoded contents of a license (which may already be expired). """
  def __init__(self, decoded):
    self.decoded = decoded

  def validate_entitlement_requirement(self, entitlement_req, check_time):
    """ Evaluates the requirement against every matching entitlement as of
        `check_time` and returns the best-sorting result. When no entitlement
        matches, a result without an entitlement is returned.
    """
    candidates = [EntitlementValidationResult(entitlement_req, check_time, granted)
                  for granted in self._find_entitlements(entitlement_req.name)]
    if not candidates:
      return EntitlementValidationResult(entitlement_req, check_time)

    return sorted(candidates)[0]

  def _find_entitlements(self, entitlement_name):
    """ Yields an Entitlement for each subscription that grants `entitlement_name`.
        Each one expires at the earlier of the subscription's own expiration
        and the license-wide expiration.
    """
    license_wide = Expiration(
      ExpirationType.license_wide,
      _get_date(self.decoded, 'expirationDate'),
    )

    subscriptions = self.decoded.get('subscriptions', {})
    for subscription in subscriptions.values():
      granted_count = subscription.get('entitlements', {}).get(entitlement_name)
      if granted_count is None:
        continue

      effective_expiration = min(self._sub_expiration(subscription), license_wide)
      product = subscription.get('productName', 'unknown')
      yield Entitlement(entitlement_name, granted_count, product, effective_expiration)

  @staticmethod
  def _sub_expiration(subscription):
    """ Builds the Expiration (type, end date and grace period) for a subscription. """
    # Trial-only subscriptions are the one case keyed off trialEnd.
    if subscription.get('trialOnly', False):
      trial_end = _get_date(subscription, 'trialEnd')
      return Expiration(ExpirationType.trial_only, trial_end, TRIAL_GRACE_PERIOD)

    # Everything else is keyed off serviceEnd; only the type and grace differ.
    service_end = _get_date(subscription, 'serviceEnd')

    if subscription.get('inTrial', False):
      kind, grace = ExpirationType.in_trial, TRIAL_GRACE_PERIOD
    elif subscription.get('durationPeriod') == 'yearly':
      kind, grace = ExpirationType.yearly, YEARLY_GRACE_PERIOD
    else:
      # Anything not explicitly yearly is treated as a monthly subscription.
      kind, grace = ExpirationType.monthly, MONTHLY_GRACE_PERIOD

    return Expiration(kind, service_end, grace)

  def validate(self, config):
    """ Returns a list of EntitlementValidationResult objects, one per requirement
        generated from `config`, all evaluated at the current time.
    """
    check_time = datetime.now()
    results = []
    for requirement in _gen_entitlement_requirements(config):
      results.append(self.validate_entitlement_requirement(requirement, check_time))
    return results
 | |
| 
 | |
| 
 | |
# PEM-encoded public key used by default to verify license signatures.
_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""
# Parsed once at import time; decode_license() falls back to this key when the
# caller does not pass a key instance of its own.
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA,
                                               backend=default_backend())
 | |
| 
 | |
def decode_license(license_contents, public_key_instance=None):
  """ Decodes the specified license contents, returning the decoded license.

      The contents are decoded as a JWT using `public_key_instance`, or the
      built-in production key when none is given. The token's 'license' claim
      is then parsed as JSON and wrapped in a License.

      Raises LicenseDecodeError if the JWT cannot be decoded or the 'license'
      claim is not valid JSON.
  """
  license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY
  try:
    jwt_data = jwt.decode(license_contents, key=license_public_key)
  except jwt.exceptions.DecodeError as de:
    logger.exception('Could not decode license file')
    # Format the exception itself rather than `de.message`: that attribute was
    # deprecated in Python 2.6 and does not exist on Python 3, where reading it
    # would raise AttributeError and mask the real error.
    raise LicenseDecodeError('Could not decode license found: %s' % de)

  try:
    # A missing 'license' claim decodes to an empty license.
    decoded = json.loads(jwt_data.get('license', '{}'))
  except ValueError as ve:
    logger.exception('Could not decode license file')
    raise LicenseDecodeError('Could not decode license found: %s' % ve)

  return License(decoded)
 | |
| 
 | |
| 
 | |
# How long the validator thread sleeps between checks while the license is sufficient.
LICENSE_VALIDATION_INTERVAL = 3600 # seconds
# Shorter sleep used after a check that found the license insufficient or unreadable.
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60 # seconds


# A named entitlement together with the count of it that is required.
# NOTE(review): the namedtuple's typename is 'EntitlementRequirements' (plural)
# while the module-level name is singular; this only shows up in repr().
EntitlementRequirement = namedtuple('EntitlementRequirements', ['name', 'count'])
 | |
| 
 | |
| 
 | |
def _gen_entitlement_requirements(config_obj):
  """ Builds the entitlement requirements implied by the given config: one
      base entitlement, plus one deployment entitlement per entry in
      DISTRIBUTED_STORAGE_CONFIG.
  """
  storage_config = config_obj.get('DISTRIBUTED_STORAGE_CONFIG', [])
  base_requirement = EntitlementRequirement(QUAY_ENTITLEMENT, 1)
  deployments_requirement = EntitlementRequirement(QUAY_DEPLOYMENTS_ENTITLEMENT,
                                                   len(storage_config))
  return [base_requirement, deployments_requirement]
 | |
| 
 | |
| 
 | |
class LicenseValidator(Thread):
  """
  Background thread that repeatedly reloads the license and re-checks it against the
  entitlement requirements derived from the config.

  This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
  synchronization primitive.
  """
  def __init__(self, config_provider, *args, **kwargs):
    self._config_provider = config_provider
    self._entitlement_requirements = _gen_entitlement_requirements(
      config_provider.get_config() or {})

    # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
    # Both flags start out True until the first check has run.
    self._license_is_insufficient = multiprocessing.Value(c_bool, True)
    self._license_expiring_soon = multiprocessing.Value(c_bool, True)

    super(LicenseValidator, self).__init__(*args, **kwargs)
    self.daemon = True

  @property
  def expiring_soon(self):
    """ Returns whether the license will be expiring soon (a week from now). """
    return self._license_expiring_soon.value

  @property
  def insufficient(self):
    """ Returns whether the last check found the license insufficient. """
    return self._license_is_insufficient.value

  def _any_requirement_unmet(self, current_license, check_time):
    """ Returns True as soon as one requirement is not met as of `check_time`. """
    for requirement in self._entitlement_requirements:
      result = current_license.validate_entitlement_requirement(requirement, check_time)
      if not result.is_met():
        return True
    return False

  def compute_license_sufficiency(self):
    """ Re-evaluates every requirement, both as of now and as of a week from
        now, and publishes the two outcomes through the shared flags. A
        license that cannot be loaded or decoded counts as insufficient (and
        not as "expiring soon").
        Returns True if any requirements are not met, and False if all are met.
    """
    try:
      current_license = self._config_provider.get_license()
      now = datetime.now()
      soon = now + LICENSE_SOON_DELTA
      any_invalid = self._any_requirement_unmet(current_license, now)
      soon_invalid = self._any_requirement_unmet(current_license, soon)
      logger.debug('updating license license_is_insufficient to %s', any_invalid)
      logger.debug('updating license license_expiring_soon to %s', soon_invalid)
    except (IOError, LicenseDecodeError):
      logger.exception('failed to validate license')
      any_invalid, soon_invalid = True, False

    self._license_is_insufficient.value = any_invalid
    self._license_expiring_soon.value = soon_invalid
    return any_invalid

  def run(self):
    """ Checks the license forever, sleeping between checks; the sleep is
        shorter after a check that found the license insufficient.
    """
    logger.debug('Starting license validation thread')
    while True:
      if self.compute_license_sufficiency():
        sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL
      else:
        sleep_time = LICENSE_VALIDATION_INTERVAL
      logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
      time.sleep(sleep_time)

  def enforce_license_before_request(self, blueprint, response_func=None):
    """
    Registers a before-request hook on the given Flask blueprint. While this validator reports the
    license as insufficient, the hook answers with `response_func()`; by default that is an
    HTTP 402 response.
    """
    if response_func is None:
      def _response_func():
        return make_response('License is insufficient.', 402)
      response_func = _response_func

    def _enforce_license():
      if not self.insufficient:
        # Returning None lets the request proceed normally.
        return None

      logger.debug('blocked interaction due to insufficient license')
      return response_func()

    blueprint.before_request(_enforce_license)
 |