This repository has been archived on 2020-03-24. You can view files and clone it, but cannot push or open issues or pull requests.
quay/util/license.py
2016-10-24 16:04:04 -04:00

389 lines
13 KiB
Python

import json
import logging
import multiprocessing
import time
from ctypes import c_bool
from datetime import datetime, timedelta
from threading import Thread
from functools import total_ordering
from enum import Enum, IntEnum
from collections import namedtuple
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from dateutil import parser
from flask import make_response
import jwt
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Grace periods that are added on top of an expiration date before the
# corresponding subscription is treated as expired.
TRIAL_GRACE_PERIOD = timedelta(days=7)  # 1 week
MONTHLY_GRACE_PERIOD = timedelta(days=335)  # 11 months
YEARLY_GRACE_PERIOD = timedelta(days=90)  # 3 months

# NOTE(review): presumably the name the license file is stored under; it is
# not referenced anywhere else in this module.
LICENSE_FILENAME = 'license'

# Names of the entitlements that are looked up inside a license.
QUAY_ENTITLEMENT = 'software.quay'
QUAY_DEPLOYMENTS_ENTITLEMENT = 'software.quay.deployments'
class LicenseDecodeError(Exception):
    """ Raised when a license cannot be read or decoded, or when it has expired. """
def _get_date(decoded, field, default_date=datetime.min):
    """ Looks up `field` in the decoded license block and parses it as a date.

        Returns `default_date` when the field is missing or empty. Otherwise the
        value is parsed with dateutil and returned with its tzinfo removed, so
        the result is always a naive datetime.
    """
    raw_value = decoded.get(field)
    if not raw_value:
        return default_date

    parsed = parser.parse(raw_value)
    return parsed.replace(tzinfo=None)
@total_ordering
class Entitlement(object):
    """ An entitlement is a specific piece of software or functionality granted
        by a license. It has an expiration date, as well as the count of the
        things being granted. Entitlements are orderable by their counts, and
        two entitlements compare equal when their counts are equal.
    """
    def __init__(self, entitlement_name, count, product_name, expiration):
        self.name = entitlement_name
        self.count = count
        self.product_name = product_name
        self.expiration = expiration

    def __lt__(self, rhs):
        return self.count < rhs.count

    def __eq__(self, rhs):
        # total_ordering derives <= and >= from __lt__ combined with __eq__.
        # Without an explicit __eq__ the identity comparison is used, which
        # makes `a <= b` False for two distinct entitlements of equal count.
        if not isinstance(rhs, Entitlement):
            return NotImplemented
        return self.count == rhs.count

    def __ne__(self, rhs):
        # Python 2 does not derive != from __eq__, so it is spelled out here.
        outcome = self.__eq__(rhs)
        return outcome if outcome is NotImplemented else not outcome

    def __hash__(self):
        # Kept consistent with __eq__ so instances stay hashable.
        return hash(self.count)

    def __repr__(self):
        return str(dict(
            name=self.name,
            count=self.count,
            product_name=self.product_name,
            expiration=repr(self.expiration),
        ))

    def as_dict(self, for_private=False):
        """ Returns a dict view of this entitlement. Only the name is included
            unless for_private is set, in which case the count, product name and
            the private view of the expiration are added as well.
        """
        data = {
            'name': self.name,
        }

        if for_private:
            data.update({
                'count': self.count,
                'product_name': self.product_name,
                'expiration': self.expiration.as_dict(for_private=True),
            })

        return data
class ExpirationType(Enum):
    """ Enumerates the different kinds of expiration. Given an expired
        Expiration, its type tells you at which level the most restrictive
        expiration was found.
    """
    # Expiration of the license as a whole.
    license_wide = 'License Wide Expiration'

    # Expirations of subscriptions that are trials.
    trial_only = 'Trial Only Expiration'
    in_trial = 'In-Trial Expiration'

    # Expirations of regular subscriptions, by billing period.
    monthly = 'Monthly Subscription Expiration'
    yearly = 'Yearly Subscription Expiration'
@total_ordering
class Expiration(object):
    """ An Expiration is an orderable representation of an expiration date and a
        grace period. If you sort Expiration objects, they will be sorted by the
        actual cutoff date, which is the combination of the expiration date and
        the grace period. Two Expirations compare equal when their cutoff dates
        are equal.
    """
    def __init__(self, expiration_type, exp_date, grace_period=timedelta(seconds=0)):
        self.expiration_type = expiration_type
        self.expiration_date = exp_date
        self.grace_period = grace_period

    @property
    def expires_at(self):
        """ The cutoff date: the expiration date plus the grace period. """
        return self.expiration_date + self.grace_period

    def is_expired(self, now):
        """ Check if the current object should already be considered expired when
            compared with the passed in datetime object. The cutoff itself is not
            yet expired; only a `now` strictly after it is.
        """
        return self.expires_at < now

    def __lt__(self, rhs):
        return self.expires_at < rhs.expires_at

    def __eq__(self, rhs):
        # total_ordering derives <= and >= from __lt__ combined with __eq__.
        # Without an explicit __eq__ the identity comparison is used, which
        # makes `a <= b` False for two distinct objects with the same cutoff.
        if not isinstance(rhs, Expiration):
            return NotImplemented
        return self.expires_at == rhs.expires_at

    def __ne__(self, rhs):
        # Python 2 does not derive != from __eq__, so it is spelled out here.
        outcome = self.__eq__(rhs)
        return outcome if outcome is NotImplemented else not outcome

    def __hash__(self):
        # Kept consistent with __eq__ so instances stay hashable.
        return hash(self.expires_at)

    def __repr__(self):
        return str(dict(
            expiration_type=repr(self.expiration_type),
            expiration_date=repr(self.expiration_date),
            grace_period=repr(self.grace_period),
        ))

    def as_dict(self, for_private=False):
        """ Returns a dict view of this expiration. Only the expiration type is
            included unless for_private is set, in which case the expiration date
            and grace period are added as strings.
        """
        data = {
            'expiration_type': str(self.expiration_type),
        }

        if for_private:
            data.update({
                'expiration_date': str(self.expiration_date),
                'grace_period': str(self.grace_period),
            })

        return data
class EntitlementStatus(IntEnum):
    """ Describes how well an Entitlement satisfies the requirement it was
        checked against. For instance, when the software requires 9 items but
        the Entitlement only grants 7, the status is insufficient_count.
    """
    # The entitlement is unexpired and grants a sufficient count.
    met = 0

    # The entitlement's expiration has passed.
    expired = 1

    # The entitlement grants fewer items than the requirement asks for.
    insufficient_count = 2

    # No entitlement was available for the requirement.
    no_matching = 3
@total_ordering
class EntitlementValidationResult(object):
    """ An EntitlementValidationResult encodes the combination of a specific
        entitlement and the software requirement which caused it to be examined.
        They are orderable by the value of the EntitlementStatus enum, and will
        in general be sorted by most to least satisfiable status type. Two
        results compare equal when their statuses are equal.
    """
    def __init__(self, requirement, created_at, entitlement=None):
        self.requirement = requirement
        self._created_at = created_at
        self.entitlement = entitlement

    def get_status(self):
        """ Returns the EntitlementStatus when comparing the specified Entitlement
            with the corresponding requirement. Expiration is checked against the
            creation time of this result, and takes precedence over the count.
        """
        if self.entitlement is not None:
            if self.entitlement.expiration.is_expired(self._created_at):
                return EntitlementStatus.expired

            if self.entitlement.count < self.requirement.count:
                return EntitlementStatus.insufficient_count

            return EntitlementStatus.met

        return EntitlementStatus.no_matching

    def is_met(self):
        """ Returns whether this specific EntitlementValidationResult meets all
            of the criteria for being sufficient, including unexpired (or in the
            grace period), and with a sufficient count.
        """
        return self.get_status() == EntitlementStatus.met

    def __lt__(self, rhs):
        return self.get_status() < rhs.get_status()

    def __eq__(self, rhs):
        # total_ordering derives <= and >= from __lt__ combined with __eq__.
        # Without an explicit __eq__ the identity comparison is used, which
        # makes `a <= b` False for two distinct results with the same status.
        if not isinstance(rhs, EntitlementValidationResult):
            return NotImplemented
        return self.get_status() == rhs.get_status()

    def __ne__(self, rhs):
        # Python 2 does not derive != from __eq__, so it is spelled out here.
        outcome = self.__eq__(rhs)
        return outcome if outcome is NotImplemented else not outcome

    def __hash__(self):
        # Kept consistent with __eq__ so instances stay hashable.
        return hash(self.get_status())

    def __repr__(self):
        return str(dict(
            requirement=repr(self.requirement),
            created_at=repr(self._created_at),
            entitlement=repr(self.entitlement),
        ))

    def as_dict(self, for_private=False):
        """ Returns a dict view with the requirement (name and count), the status
            as a string and, when an entitlement is attached, its dict view with
            the same for_private setting.
        """
        def req_view():
            return {
                'name': self.requirement.name,
                'count': self.requirement.count,
            }

        data = {
            'requirement': req_view(),
            'status': str(self.get_status()),
        }

        if self.entitlement is not None:
            data['entitlement'] = self.entitlement.as_dict(for_private=for_private)

        return data
class License(object):
    """ Wraps the decoded payload of a license (a dict) and checks entitlement
        requirements against it. The license may already be expired.
    """
    def __init__(self, decoded):
        self.decoded = decoded

    def validate_entitlement_requirement(self, entitlement_req, check_time):
        """ Returns the best EntitlementValidationResult for the given requirement
            at the given time. When the license holds no entitlement with the
            requirement's name, the result carries no entitlement.
        """
        candidates = [
            EntitlementValidationResult(entitlement_req, check_time, found)
            for found in self._find_entitlements(entitlement_req.name)
        ]
        if not candidates:
            return EntitlementValidationResult(entitlement_req, check_time)

        # Results order by status, so the smallest one is the most satisfiable;
        # on ties the first one found wins.
        return min(candidates)

    def _find_entitlements(self, entitlement_name):
        """ Yields an Entitlement for every subscription that lists the given
            entitlement name. Each one expires at the earlier of the
            subscription's own expiration and the license-wide expiration.
        """
        license_wide = Expiration(
            ExpirationType.license_wide,
            _get_date(self.decoded, 'expirationDate'),
        )

        subscriptions = self.decoded.get('subscriptions', {})
        for subscription in subscriptions.values():
            granted = subscription.get('entitlements', {}).get(entitlement_name)
            if granted is None:
                continue

            effective_expiration = min(self._sub_expiration(subscription), license_wide)
            yield Entitlement(
                entitlement_name,
                granted,
                subscription.get('productName', 'unknown'),
                effective_expiration,
            )

    @staticmethod
    def _sub_expiration(subscription):
        """ Builds the Expiration that applies to a single subscription. """
        # Trial-only subscriptions end at trialEnd rather than serviceEnd.
        if subscription.get('trialOnly', False):
            trial_end = _get_date(subscription, 'trialEnd')
            return Expiration(ExpirationType.trial_only, trial_end, TRIAL_GRACE_PERIOD)

        # Everything else ends at serviceEnd; only the type and grace period vary.
        service_end = _get_date(subscription, 'serviceEnd')
        if subscription.get('inTrial', False):
            kind, grace = ExpirationType.in_trial, TRIAL_GRACE_PERIOD
        elif subscription.get('durationPeriod') == 'yearly':
            kind, grace = ExpirationType.yearly, YEARLY_GRACE_PERIOD
        else:
            # Any other duration period is treated as monthly.
            kind, grace = ExpirationType.monthly, MONTHLY_GRACE_PERIOD

        return Expiration(kind, service_end, grace)

    def validate(self, config):
        """ Returns a list of EntitlementValidationResult objects, one for each
            requirement generated from the given config, all checked against the
            same current time.
        """
        requirements = _gen_entitlement_requirements(config)
        checked_at = datetime.now()

        results = []
        for requirement in requirements:
            results.append(self.validate_entitlement_requirement(requirement, checked_at))
        return results
# PEM-encoded public key that decode_license uses to verify a license when the
# caller does not supply a key of its own.
_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""
# The key is parsed once, when this module is imported.
# NOTE(review): the PEM data is passed as a plain str literal; on Python 3 the
# cryptography loader presumably needs bytes here. Confirm the target interpreter.
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA,
backend=default_backend())
def decode_license(license_contents, public_key_instance=None):
    """ Decodes the specified license contents, returning the decoded license.

        license_contents is the JWT-encoded license. public_key_instance, when
        given, replaces the production public key used to decode it.

        Raises LicenseDecodeError if the JWT cannot be decoded, or if its
        'license' claim is not valid JSON.
    """
    license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY

    try:
        jwt_data = jwt.decode(license_contents, key=license_public_key)
    except jwt.exceptions.DecodeError as de:
        logger.exception('Could not decode license file')
        # Format the exception itself rather than its `.message` attribute:
        # that attribute is deprecated since Python 2.6 and absent on Python 3,
        # where reading it would raise AttributeError and mask this error.
        raise LicenseDecodeError('Could not decode license found: %s' % de)

    try:
        # The license payload is a JSON document stored under the 'license' claim.
        decoded = json.loads(jwt_data.get('license', '{}'))
    except ValueError as ve:
        logger.exception('Could not decode license file')
        raise LicenseDecodeError('Could not decode license found: %s' % ve)

    return License(decoded)
# How long the validator thread sleeps between checks, in seconds: an hour
# after a sufficient result, a minute after an insufficient one.
LICENSE_VALIDATION_INTERVAL = 60 * 60
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60

# A requirement the software places on the license: the name of an
# entitlement and the count that is needed of it.
# NOTE(review): the typename is plural while the bound name is singular; this
# shows up in repr() output and should be confirmed before it is changed.
EntitlementRequirement = namedtuple('EntitlementRequirements', 'name count')
def _gen_entitlement_requirements(config_obj):
    """ Builds the list of entitlement requirements implied by the given config:
        one Quay entitlement, plus one deployment for every entry found under
        DISTRIBUTED_STORAGE_CONFIG (none when the key is missing).
    """
    storage_config = config_obj.get('DISTRIBUTED_STORAGE_CONFIG', [])
    region_count = len(storage_config)

    quay_requirement = EntitlementRequirement(QUAY_ENTITLEMENT, 1)
    deployments_requirement = EntitlementRequirement(QUAY_DEPLOYMENTS_ENTITLEMENT, region_count)
    return [quay_requirement, deployments_requirement]
class LicenseValidator(Thread):
    """
    LicenseValidator is a daemon thread which periodically reloads the license and checks it
    against the entitlement requirements derived from the config. It is meant to be started before
    the registry gunicorn workers fork, and publishes its verdict through shared memory.
    """
    def __init__(self, config_provider, *args, **kwargs):
        loaded_config = config_provider.get_config() or {}

        self._config_provider = config_provider
        self._entitlement_requirements = _gen_entitlement_requirements(loaded_config)

        # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
        # The license counts as insufficient until the first check has run.
        self._license_is_insufficient = multiprocessing.Value(c_bool, True)

        super(LicenseValidator, self).__init__(*args, **kwargs)
        self.daemon = True

    @property
    def insufficient(self):
        """ The verdict of the most recent check: True when the license is insufficient. """
        return self._license_is_insufficient.value

    def compute_license_sufficiency(self):
        """ Loads the current license, checks every requirement against it and
            stores the verdict, which is used to disable the software. A license
            that cannot be read or decoded counts as insufficient.
            Returns True if any requirement is unmet, and False if all are met.
        """
        try:
            current_license = self._config_provider.get_license()
            checked_at = datetime.now()
            any_invalid = any(
                not current_license.validate_entitlement_requirement(requirement, checked_at).is_met()
                for requirement in self._entitlement_requirements
            )
            logger.debug('updating license license_is_insufficient to %s', any_invalid)
        except (IOError, LicenseDecodeError):
            logger.exception('failed to validate license')
            any_invalid = True

        self._license_is_insufficient.value = any_invalid
        return any_invalid

    def run(self):
        """ Re-checks the license forever, sleeping between checks. The sleep is
            shorter after an insufficient result than after a sufficient one.
        """
        logger.debug('Starting license validation thread')
        while True:
            if self.compute_license_sufficiency():
                sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL
            else:
                sleep_time = LICENSE_VALIDATION_INTERVAL

            logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
            time.sleep(sleep_time)

    def enforce_license_before_request(self, blueprint, response_func=None):
        """
        Registers a before-request check on the given Flask blueprint. While this validator reports
        the license as insufficient, the check answers with the result of response_func, which
        defaults to a HTTP 402 response.
        """
        def _default_response():
            return make_response('License is insufficient.', 402)

        respond = _default_response if response_func is None else response_func

        def _enforce_license():
            if not self.insufficient:
                # Nothing is returned, so the request proceeds as usual.
                return None

            logger.debug('blocked interaction due to insufficient license')
            return respond()

        blueprint.before_request(_enforce_license)