Refactor our license code to be entitlement centric.

This commit is contained in:
Jake Moshenko 2016-10-18 18:47:51 -04:00
parent d90398e9ff
commit 9f1c12e413
4 changed files with 238 additions and 130 deletions

View file

@ -19,7 +19,7 @@ from data.database import User
from util.config.configutil import add_enterprise_config_defaults
from util.config.database import sync_database_with_config
from util.config.validator import validate_service_for_config, CONFIG_FILENAMES
from util.license import decode_license, LicenseError
from util.license import decode_license, LicenseDecodeError
from data.runmigration import run_alembic_migration
from data.users import get_federated_service_name, get_users_handler
@ -283,15 +283,17 @@ class SuperUserSetAndValidateLicense(ApiResource):
license_contents = request.get_json()['license']
try:
decoded_license = decode_license(license_contents)
except LicenseError as le:
except LicenseDecodeError as le:
raise InvalidRequest(le.message)
if decoded_license.is_expired:
raise InvalidRequest('License has expired')
statuses = decoded_license.validate({})
all_met = all(status.is_met() for status in statuses)
if not all_met:
raise InvalidRequest('License is insufficient')
config_provider.save_license(license_contents)
return {
'decoded': decoded_license.subscription,
'decoded': {},
'success': True
}

View file

@ -11,7 +11,7 @@ from flask import request, make_response, jsonify
import features
from app import app, avatar, superusers, authentication, config_provider
from app import app, avatar, superusers, authentication, config_provider, license_validator
from auth import scopes
from auth.auth_context import get_authenticated_user
from auth.permissions import SuperUserPermission
@ -23,7 +23,7 @@ from endpoints.api.logs import get_logs, get_aggregate_logs
from data import model
from data.database import ServiceKeyApprovalType
from util.useremails import send_confirmation_email, send_recovery_email
from util.license import decode_license, LicenseError
from util.license import decode_license, LicenseDecodeError
logger = logging.getLogger(__name__)
@ -851,11 +851,13 @@ class SuperUserLicense(ApiResource):
if SuperUserPermission().can():
try:
decoded_license = config_provider.get_license()
except LicenseError as le:
except LicenseDecodeError as le:
raise InvalidRequest(le.message)
if decoded_license.is_expired:
raise InvalidRequest('License has expired')
statuses = decoded_license.validate(app.config)
all_met = all(status.is_met() for status in statuses)
if not all_met:
raise InvalidRequest('License is insufficient')
return {
'decoded': decoded_license.subscription,
@ -875,15 +877,20 @@ class SuperUserLicense(ApiResource):
license_contents = request.get_json()['license']
try:
decoded_license = decode_license(license_contents)
except LicenseError as le:
except LicenseDecodeError as le:
raise InvalidRequest(le.message)
if decoded_license.is_expired:
raise InvalidRequest('License has expired')
statuses = decoded_license.validate(app.config)
all_met = all(status.is_met() for status in statuses)
if not all_met:
raise InvalidRequest('License is insufficient')
config_provider.save_license(license_contents)
license_validator.compute_license_sufficiency()
return {
'decoded': decoded_license.subscription,
'decoded': {},
'success': True
}

View file

@ -1,7 +1,7 @@
import logging
import yaml
from util.license import LICENSE_FILENAME, LicenseError, decode_license
from util.license import LICENSE_FILENAME, LicenseDecodeError, decode_license
logger = logging.getLogger(__name__)
@ -104,13 +104,13 @@ class BaseProvider(object):
""" Returns the contents of the license file. """
if not self.has_license_file():
msg = 'Could not find license file. Please make sure it is in your config volume.'
raise LicenseError(msg)
raise LicenseDecodeError(msg)
try:
return self.get_volume_file(LICENSE_FILENAME)
except IOError:
msg = 'Could not open license file. Please make sure it is in your config volume.'
raise LicenseError(msg)
raise LicenseDecodeError(msg)
def get_license(self):
""" Returns the decoded license, if any. """

View file

@ -6,6 +6,9 @@ import time
from ctypes import c_bool
from datetime import datetime, timedelta
from threading import Thread
from functools import total_ordering
from enum import Enum, IntEnum
from collections import namedtuple
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key
@ -21,130 +24,209 @@ logger = logging.getLogger(__name__)
TRIAL_GRACE_PERIOD = timedelta(7, 0) # 1 week
MONTHLY_GRACE_PERIOD = timedelta(335, 0) # 11 months
YEARLY_GRACE_PERIOD = timedelta(90, 0) # 3 months
LICENSE_PRODUCT_NAME = "quay-enterprise"
LICENSE_FILENAME = 'license'
class LicenseError(Exception):
class LicenseDecodeError(Exception):
""" Exception raised if the license could not be read, decoded or has expired. """
pass
class LicenseDecodeError(LicenseError):
""" Exception raised if the license could not be decoded. """
pass
class LicenseValidationError(LicenseError):
""" Exception raised if the license could not be validated. """
pass
def _get_date(decoded, field):
def _get_date(decoded, field, default_date=datetime.min):
""" Retrieves the encoded date found at the given field under the decoded license block. """
date_str = decoded.get(field)
return parser.parse(date_str).replace(tzinfo=None) if date_str else None
return parser.parse(date_str).replace(tzinfo=None) if date_str else default_date
class LicenseExpirationDate(object):
def __init__(self, title, expiration_date, grace_period=None):
self.title = title
self.expiration_date = expiration_date
self.grace_period = grace_period or timedelta(seconds=0)
@total_ordering
class Entitlement(object):
""" An entitlement is a specific piece of software or functionality granted
by a license. It has an expiration date, as well as the count of the
things being granted. Entitlements are orderable by their counts.
"""
def __init__(self, entitlement_name, count, product_name, expiration):
self.name = entitlement_name
self.count = count
self.product_name = product_name
self.expiration = expiration
def check_expired(self, cutoff_date=None):
return self.expiration_and_grace <= (cutoff_date or datetime.now())
def __lt__(self, rhs):
return self.count < rhs.count
def __repr__(self):
return str(dict(
name=self.name,
count=self.count,
product_name=self.product_name,
expiration=repr(self.expiration),
))
class ExpirationType(Enum):
""" An enum which represents the different possible types of expirations. If
you posess an expired enum, you can use this to figure out at what level
the expiration was most restrictive.
"""
license_wide = 'License Wide Expiration'
trial_only = 'Trial Only Expiration'
in_trial = 'In-Trial Expiration'
monthly = 'Monthly Subscription Expiration'
yearly = 'Yearly Subscription Expiration'
@total_ordering
class Expiration(object):
""" An Expiration is an orderable representation of an expiration date and a
grace period. If you sort Expiration objects, they will be sorted by the
actual cutoff date, which is the combination of the expiration date and
the grace period.
"""
def __init__(self, expiration_type, exp_date, grace_period=timedelta(seconds=0)):
self.expiration_type = expiration_type
self.expiration_date = exp_date
self.grace_period = grace_period
@property
def expiration_and_grace(self):
def expires_at(self):
return self.expiration_date + self.grace_period
def __str__(self):
return 'License expiration "%s" date %s with grace %s: %s' % (self.title, self.expiration_date,
self.grace_period,
self.check_expired())
def is_expired(self, now):
""" Check if the current object should already be considered expired when
compared with the passed in datetime object.
"""
return self.expires_at < now
def __lt__(self, rhs):
return self.expires_at < rhs.expires_at
def __repr__(self):
return str(dict(
expiration_type=repr(self.expiration_type),
expiration_date=repr(self.expiration_date),
grace_period=repr(self.grace_period),
))
class EntitlementStatus(IntEnum):
""" An EntitlementStatus represent the current effectiveness of an
Entitlement when compared with its corresponding requirement. As an
example, if the software requires 9 items, and the Entitlement only
provides for 7, you would use an insufficient_count status.
"""
met = 0
expired = 1
insufficient_count = 2
no_matching = 3
@total_ordering
class EntitlementValidationResult(object):
""" An EntitlementValidationResult encodes the combination of a specific
entitlement and the software requirement which caused it to be examined.
They are orderable by the value of the EntitlementStatus enum, and will
in general be sorted by most to least satisfiable status type.
"""
def __init__(self, requirement, created_at, entitlement=None):
self.requirement = requirement
self._created_at = created_at
self.entitlement = entitlement
def get_status(self):
""" Returns the EntitlementStatus when comparing the specified Entitlement
with the corresponding requirement.
"""
if self.entitlement is not None:
if self.entitlement.expiration.is_expired(self._created_at):
return EntitlementStatus.expired
if self.entitlement.count < self.requirement.count:
return EntitlementStatus.insufficient_count
return EntitlementStatus.met
return EntitlementStatus.no_matching
def is_met(self):
""" Returns whether this specific EntitlementValidationResult meets all
of the criteria for being sufficient, including unexpired (or in the
grace period), and with a sufficient count.
"""
return self.get_status() == EntitlementStatus.met
def __lt__(self, rhs):
return self.get_status() < rhs.get_status()
def __repr__(self):
return str(dict(
requirement=repr(self.requirement),
created_at=repr(self._created_at),
entitlement=repr(self.entitlement),
))
class License(object):
""" License represents a fully decoded and validated (but potentially expired) license. """
def __init__(self, decoded):
self.decoded = decoded
@property
def subscription(self):
""" Returns the Quay Enterprise subscription, if any. """
def validate_entitlement_requirement(self, entitlement_req, check_time):
all_active_entitlements = list(self._find_entitlements(entitlement_req.name))
if len(all_active_entitlements) == 0:
return EntitlementValidationResult(entitlement_req, check_time)
entitlement_results = [EntitlementValidationResult(entitlement_req, check_time, ent)
for ent in all_active_entitlements]
entitlement_results.sort()
return entitlement_results[0]
def _find_entitlements(self, entitlement_name):
license_expiration = Expiration(
ExpirationType.license_wide,
_get_date(self.decoded, 'expirationDate'),
)
for sub in self.decoded.get('subscriptions', {}).values():
if sub.get('productName') == LICENSE_PRODUCT_NAME:
return sub
entitlement_count = sub.get('entitlements', {}).get(entitlement_name)
return None
if entitlement_count is not None:
entitlement_expiration = min(self._sub_expiration(sub), license_expiration)
yield Entitlement(
entitlement_name,
entitlement_count,
sub.get('productName', 'unknown'),
entitlement_expiration,
)
@staticmethod
def _sub_expiration(subscription):
# A trial license has its own end logic, and uses the trialEnd property
if subscription.get('trialOnly', False):
trial_expiration = Expiration(
ExpirationType.trial_only,
_get_date(subscription, 'trialEnd'),
TRIAL_GRACE_PERIOD,
)
return trial_expiration
# From here we always use the serviceEnd
service_end = _get_date(subscription, 'serviceEnd')
if subscription.get('inTrial', False):
return Expiration(ExpirationType.in_trial, service_end, TRIAL_GRACE_PERIOD)
if subscription.get('durationPeriod') == 'yearly':
return Expiration(ExpirationType.yearly, service_end, YEARLY_GRACE_PERIOD)
# We assume monthly license unless specified otherwise
return Expiration(ExpirationType.monthly, service_end, MONTHLY_GRACE_PERIOD)
@property
def is_expired(self):
cutoff_date = datetime.now()
return bool([dt for dt in self._get_expiration_dates() if dt.check_expired(cutoff_date)])
def validate(self, config):
""" Validates the license and all its entitlements against the given config. """
# Check that the license has not expired.
if self.is_expired:
raise LicenseValidationError('License has expired')
# Check the maximum number of replication regions.
max_regions = min(self.decoded.get('entitlements', {}).get('software.quay.regions', 1), 1)
config_regions = len(config.get('DISTRIBUTED_STORAGE_CONFIG', []))
if max_regions != -1 and config_regions > max_regions:
msg = '{} regions configured, but license file allows up to {}'.format(config_regions,
max_regions)
raise LicenseValidationError(msg)
def _get_expiration_dates(self):
# Check if the license overall has expired.
expiration_date = _get_date(self.decoded, 'expirationDate')
if expiration_date is None:
yield LicenseExpirationDate('No valid Tectonic Account License', datetime.min)
return
yield LicenseExpirationDate('Tectonic Account License', expiration_date)
# Check for any QE subscriptions.
sub = self.subscription
if sub is None:
yield LicenseExpirationDate('No Quay Enterprise Subscription', datetime.min)
return
# Check for a trial-only license.
if sub.get('trialOnly', False):
trial_end_date = _get_date(sub, 'trialEnd')
if trial_end_date is None:
yield LicenseExpirationDate('Invalid trial subscription', datetime.min)
else:
yield LicenseExpirationDate('Trial subscription', trial_end_date, TRIAL_GRACE_PERIOD)
return
# Check for a normal license that is in trial.
service_end_date = _get_date(sub, 'serviceEnd')
if service_end_date is None:
yield LicenseExpirationDate('No valid Quay Enterprise Subscription', datetime.min)
return
if sub.get('inTrial', False):
# If the subscription is in a trial, but not a trial only
# subscription, give 7 days after trial end to update license
# to one which has been paid (they've put in a credit card and it
# might auto convert, so we could assume it will auto-renew)
yield LicenseExpirationDate('In-trial subscription', service_end_date, TRIAL_GRACE_PERIOD)
# Otherwise, check the service expiration.
duration_period = sub.get('durationPeriod', 'months')
# If the subscription is monthly, give 3 months grace period
if duration_period == "months":
yield LicenseExpirationDate('Monthly subscription', service_end_date, MONTHLY_GRACE_PERIOD)
if duration_period == "years":
yield LicenseExpirationDate('Yearly subscription', service_end_date, YEARLY_GRACE_PERIOD)
""" Returns a list of EntitlementValidationResult objects, one per requirement.
"""
requirements = _gen_entitlement_requirements(config)
now = datetime.now()
return [self.validate_entitlement_requirement(req, now) for req in requirements]
_PROD_LICENSE_PUBLIC_KEY_DATA = """
@ -183,6 +265,17 @@ LICENSE_VALIDATION_INTERVAL = 3600 # seconds
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60 # seconds
EntitlementRequirement = namedtuple('EntitlementRequirements', ['name', 'count'])
def _gen_entitlement_requirements(config_obj):
config_regions = len(config_obj.get('DISTRIBUTED_STORAGE_CONFIG', []))
return [
EntitlementRequirement('software.quay', 1),
EntitlementRequirement('software.quay.regions', config_regions),
]
class LicenseValidator(Thread):
"""
LicenseValidator is a thread that asynchronously reloads and validates license files.
@ -192,35 +285,41 @@ class LicenseValidator(Thread):
"""
def __init__(self, config_provider, *args, **kwargs):
self._config_provider = config_provider
self._entitlement_requirements = _gen_entitlement_requirements(config_provider.get_config())
# multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
self._license_is_expired = multiprocessing.Value(c_bool, True)
self._license_is_insufficient = multiprocessing.Value(c_bool, True)
super(LicenseValidator, self).__init__(*args, **kwargs)
self.daemon = True
@property
def expired(self):
return self._license_is_expired.value
def insufficient(self):
return self._license_is_insufficient.value
def _check_expiration(self):
def compute_license_sufficiency(self):
""" Check whether all of our requirements are met, and set the status of
the result of the check, which will be used to disable the software.
Returns True if any requirements are not met, and False if all are met.
"""
try:
current_license = self._config_provider.get_license()
is_expired = current_license.is_expired
logger.debug('updating license expiration to %s', is_expired)
self._license_is_expired.value = is_expired
except (IOError, LicenseError):
now = datetime.now()
any_invalid = not all(current_license.validate_entitlement_requirement(req, now).is_met()
for req in self._entitlement_requirements)
logger.debug('updating license license_is_insufficient to %s', any_invalid)
except (IOError, LicenseDecodeError):
logger.exception('failed to validate license')
is_expired = True
self._license_is_expired.value = is_expired
any_invalid = True
return is_expired
self._license_is_insufficient.value = any_invalid
return any_invalid
def run(self):
logger.debug('Starting license validation thread')
while True:
expired = self._check_expiration()
sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if expired else LICENSE_VALIDATION_INTERVAL
invalid = self.compute_license_sufficiency()
sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if invalid else LICENSE_VALIDATION_INTERVAL
logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
time.sleep(sleep_time)
@ -231,11 +330,11 @@ class LicenseValidator(Thread):
"""
if response_func is None:
def _response_func():
return make_response('License has expired.', 402)
return make_response('License is insufficient.', 402)
response_func = _response_func
def _enforce_license():
if self.expired:
logger.debug('blocked interaction due to expired license')
if self.insufficient:
logger.debug('blocked interaction due to insufficient license')
return response_func()
blueprint.before_request(_enforce_license)