Refactor our license code to be entitlement centric.

This commit is contained in:
Jake Moshenko 2016-10-18 18:47:51 -04:00
parent d90398e9ff
commit 9f1c12e413
4 changed files with 238 additions and 130 deletions

View file

@ -19,7 +19,7 @@ from data.database import User
from util.config.configutil import add_enterprise_config_defaults from util.config.configutil import add_enterprise_config_defaults
from util.config.database import sync_database_with_config from util.config.database import sync_database_with_config
from util.config.validator import validate_service_for_config, CONFIG_FILENAMES from util.config.validator import validate_service_for_config, CONFIG_FILENAMES
from util.license import decode_license, LicenseError from util.license import decode_license, LicenseDecodeError
from data.runmigration import run_alembic_migration from data.runmigration import run_alembic_migration
from data.users import get_federated_service_name, get_users_handler from data.users import get_federated_service_name, get_users_handler
@ -283,15 +283,17 @@ class SuperUserSetAndValidateLicense(ApiResource):
license_contents = request.get_json()['license'] license_contents = request.get_json()['license']
try: try:
decoded_license = decode_license(license_contents) decoded_license = decode_license(license_contents)
except LicenseError as le: except LicenseDecodeError as le:
raise InvalidRequest(le.message) raise InvalidRequest(le.message)
if decoded_license.is_expired: statuses = decoded_license.validate({})
raise InvalidRequest('License has expired') all_met = all(status.is_met() for status in statuses)
if not all_met:
raise InvalidRequest('License is insufficient')
config_provider.save_license(license_contents) config_provider.save_license(license_contents)
return { return {
'decoded': decoded_license.subscription, 'decoded': {},
'success': True 'success': True
} }

View file

@ -11,7 +11,7 @@ from flask import request, make_response, jsonify
import features import features
from app import app, avatar, superusers, authentication, config_provider from app import app, avatar, superusers, authentication, config_provider, license_validator
from auth import scopes from auth import scopes
from auth.auth_context import get_authenticated_user from auth.auth_context import get_authenticated_user
from auth.permissions import SuperUserPermission from auth.permissions import SuperUserPermission
@ -23,7 +23,7 @@ from endpoints.api.logs import get_logs, get_aggregate_logs
from data import model from data import model
from data.database import ServiceKeyApprovalType from data.database import ServiceKeyApprovalType
from util.useremails import send_confirmation_email, send_recovery_email from util.useremails import send_confirmation_email, send_recovery_email
from util.license import decode_license, LicenseError from util.license import decode_license, LicenseDecodeError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -851,11 +851,13 @@ class SuperUserLicense(ApiResource):
if SuperUserPermission().can(): if SuperUserPermission().can():
try: try:
decoded_license = config_provider.get_license() decoded_license = config_provider.get_license()
except LicenseError as le: except LicenseDecodeError as le:
raise InvalidRequest(le.message) raise InvalidRequest(le.message)
if decoded_license.is_expired: statuses = decoded_license.validate(app.config)
raise InvalidRequest('License has expired') all_met = all(status.is_met() for status in statuses)
if not all_met:
raise InvalidRequest('License is insufficient')
return { return {
'decoded': decoded_license.subscription, 'decoded': decoded_license.subscription,
@ -875,15 +877,20 @@ class SuperUserLicense(ApiResource):
license_contents = request.get_json()['license'] license_contents = request.get_json()['license']
try: try:
decoded_license = decode_license(license_contents) decoded_license = decode_license(license_contents)
except LicenseError as le: except LicenseDecodeError as le:
raise InvalidRequest(le.message) raise InvalidRequest(le.message)
if decoded_license.is_expired: statuses = decoded_license.validate(app.config)
raise InvalidRequest('License has expired') all_met = all(status.is_met() for status in statuses)
if not all_met:
raise InvalidRequest('License is insufficient')
config_provider.save_license(license_contents) config_provider.save_license(license_contents)
license_validator.compute_license_sufficiency()
return { return {
'decoded': decoded_license.subscription, 'decoded': {},
'success': True 'success': True
} }

View file

@ -1,7 +1,7 @@
import logging import logging
import yaml import yaml
from util.license import LICENSE_FILENAME, LicenseError, decode_license from util.license import LICENSE_FILENAME, LicenseDecodeError, decode_license
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -104,13 +104,13 @@ class BaseProvider(object):
""" Returns the contents of the license file. """ """ Returns the contents of the license file. """
if not self.has_license_file(): if not self.has_license_file():
msg = 'Could not find license file. Please make sure it is in your config volume.' msg = 'Could not find license file. Please make sure it is in your config volume.'
raise LicenseError(msg) raise LicenseDecodeError(msg)
try: try:
return self.get_volume_file(LICENSE_FILENAME) return self.get_volume_file(LICENSE_FILENAME)
except IOError: except IOError:
msg = 'Could not open license file. Please make sure it is in your config volume.' msg = 'Could not open license file. Please make sure it is in your config volume.'
raise LicenseError(msg) raise LicenseDecodeError(msg)
def get_license(self): def get_license(self):
""" Returns the decoded license, if any. """ """ Returns the decoded license, if any. """

View file

@ -6,6 +6,9 @@ import time
from ctypes import c_bool from ctypes import c_bool
from datetime import datetime, timedelta from datetime import datetime, timedelta
from threading import Thread from threading import Thread
from functools import total_ordering
from enum import Enum, IntEnum
from collections import namedtuple
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key from cryptography.hazmat.primitives.serialization import load_pem_public_key
@ -21,130 +24,209 @@ logger = logging.getLogger(__name__)
TRIAL_GRACE_PERIOD = timedelta(7, 0) # 1 week TRIAL_GRACE_PERIOD = timedelta(7, 0) # 1 week
MONTHLY_GRACE_PERIOD = timedelta(335, 0) # 11 months MONTHLY_GRACE_PERIOD = timedelta(335, 0) # 11 months
YEARLY_GRACE_PERIOD = timedelta(90, 0) # 3 months YEARLY_GRACE_PERIOD = timedelta(90, 0) # 3 months
LICENSE_PRODUCT_NAME = "quay-enterprise"
LICENSE_FILENAME = 'license' LICENSE_FILENAME = 'license'
class LicenseError(Exception): class LicenseDecodeError(Exception):
""" Exception raised if the license could not be read, decoded or has expired. """ """ Exception raised if the license could not be read, decoded or has expired. """
pass pass
class LicenseDecodeError(LicenseError): def _get_date(decoded, field, default_date=datetime.min):
""" Exception raised if the license could not be decoded. """
pass
class LicenseValidationError(LicenseError):
""" Exception raised if the license could not be validated. """
pass
def _get_date(decoded, field):
""" Retrieves the encoded date found at the given field under the decoded license block. """ """ Retrieves the encoded date found at the given field under the decoded license block. """
date_str = decoded.get(field) date_str = decoded.get(field)
return parser.parse(date_str).replace(tzinfo=None) if date_str else None return parser.parse(date_str).replace(tzinfo=None) if date_str else default_date
class LicenseExpirationDate(object): @total_ordering
def __init__(self, title, expiration_date, grace_period=None): class Entitlement(object):
self.title = title """ An entitlement is a specific piece of software or functionality granted
self.expiration_date = expiration_date by a license. It has an expiration date, as well as the count of the
self.grace_period = grace_period or timedelta(seconds=0) things being granted. Entitlements are orderable by their counts.
"""
def __init__(self, entitlement_name, count, product_name, expiration):
self.name = entitlement_name
self.count = count
self.product_name = product_name
self.expiration = expiration
def check_expired(self, cutoff_date=None): def __lt__(self, rhs):
return self.expiration_and_grace <= (cutoff_date or datetime.now()) return self.count < rhs.count
def __repr__(self):
return str(dict(
name=self.name,
count=self.count,
product_name=self.product_name,
expiration=repr(self.expiration),
))
class ExpirationType(Enum):
""" An enum which represents the different possible types of expirations. If
you posess an expired enum, you can use this to figure out at what level
the expiration was most restrictive.
"""
license_wide = 'License Wide Expiration'
trial_only = 'Trial Only Expiration'
in_trial = 'In-Trial Expiration'
monthly = 'Monthly Subscription Expiration'
yearly = 'Yearly Subscription Expiration'
@total_ordering
class Expiration(object):
""" An Expiration is an orderable representation of an expiration date and a
grace period. If you sort Expiration objects, they will be sorted by the
actual cutoff date, which is the combination of the expiration date and
the grace period.
"""
def __init__(self, expiration_type, exp_date, grace_period=timedelta(seconds=0)):
self.expiration_type = expiration_type
self.expiration_date = exp_date
self.grace_period = grace_period
@property @property
def expiration_and_grace(self): def expires_at(self):
return self.expiration_date + self.grace_period return self.expiration_date + self.grace_period
def __str__(self): def is_expired(self, now):
return 'License expiration "%s" date %s with grace %s: %s' % (self.title, self.expiration_date, """ Check if the current object should already be considered expired when
self.grace_period, compared with the passed in datetime object.
self.check_expired()) """
return self.expires_at < now
def __lt__(self, rhs):
return self.expires_at < rhs.expires_at
def __repr__(self):
return str(dict(
expiration_type=repr(self.expiration_type),
expiration_date=repr(self.expiration_date),
grace_period=repr(self.grace_period),
))
class EntitlementStatus(IntEnum):
""" An EntitlementStatus represent the current effectiveness of an
Entitlement when compared with its corresponding requirement. As an
example, if the software requires 9 items, and the Entitlement only
provides for 7, you would use an insufficient_count status.
"""
met = 0
expired = 1
insufficient_count = 2
no_matching = 3
@total_ordering
class EntitlementValidationResult(object):
""" An EntitlementValidationResult encodes the combination of a specific
entitlement and the software requirement which caused it to be examined.
They are orderable by the value of the EntitlementStatus enum, and will
in general be sorted by most to least satisfiable status type.
"""
def __init__(self, requirement, created_at, entitlement=None):
self.requirement = requirement
self._created_at = created_at
self.entitlement = entitlement
def get_status(self):
""" Returns the EntitlementStatus when comparing the specified Entitlement
with the corresponding requirement.
"""
if self.entitlement is not None:
if self.entitlement.expiration.is_expired(self._created_at):
return EntitlementStatus.expired
if self.entitlement.count < self.requirement.count:
return EntitlementStatus.insufficient_count
return EntitlementStatus.met
return EntitlementStatus.no_matching
def is_met(self):
""" Returns whether this specific EntitlementValidationResult meets all
of the criteria for being sufficient, including unexpired (or in the
grace period), and with a sufficient count.
"""
return self.get_status() == EntitlementStatus.met
def __lt__(self, rhs):
return self.get_status() < rhs.get_status()
def __repr__(self):
return str(dict(
requirement=repr(self.requirement),
created_at=repr(self._created_at),
entitlement=repr(self.entitlement),
))
class License(object): class License(object):
""" License represents a fully decoded and validated (but potentially expired) license. """ """ License represents a fully decoded and validated (but potentially expired) license. """
def __init__(self, decoded): def __init__(self, decoded):
self.decoded = decoded self.decoded = decoded
@property def validate_entitlement_requirement(self, entitlement_req, check_time):
def subscription(self): all_active_entitlements = list(self._find_entitlements(entitlement_req.name))
""" Returns the Quay Enterprise subscription, if any. """
if len(all_active_entitlements) == 0:
return EntitlementValidationResult(entitlement_req, check_time)
entitlement_results = [EntitlementValidationResult(entitlement_req, check_time, ent)
for ent in all_active_entitlements]
entitlement_results.sort()
return entitlement_results[0]
def _find_entitlements(self, entitlement_name):
license_expiration = Expiration(
ExpirationType.license_wide,
_get_date(self.decoded, 'expirationDate'),
)
for sub in self.decoded.get('subscriptions', {}).values(): for sub in self.decoded.get('subscriptions', {}).values():
if sub.get('productName') == LICENSE_PRODUCT_NAME: entitlement_count = sub.get('entitlements', {}).get(entitlement_name)
return sub
return None if entitlement_count is not None:
entitlement_expiration = min(self._sub_expiration(sub), license_expiration)
yield Entitlement(
entitlement_name,
entitlement_count,
sub.get('productName', 'unknown'),
entitlement_expiration,
)
@staticmethod
def _sub_expiration(subscription):
# A trial license has its own end logic, and uses the trialEnd property
if subscription.get('trialOnly', False):
trial_expiration = Expiration(
ExpirationType.trial_only,
_get_date(subscription, 'trialEnd'),
TRIAL_GRACE_PERIOD,
)
return trial_expiration
# From here we always use the serviceEnd
service_end = _get_date(subscription, 'serviceEnd')
if subscription.get('inTrial', False):
return Expiration(ExpirationType.in_trial, service_end, TRIAL_GRACE_PERIOD)
if subscription.get('durationPeriod') == 'yearly':
return Expiration(ExpirationType.yearly, service_end, YEARLY_GRACE_PERIOD)
# We assume monthly license unless specified otherwise
return Expiration(ExpirationType.monthly, service_end, MONTHLY_GRACE_PERIOD)
@property
def is_expired(self):
cutoff_date = datetime.now()
return bool([dt for dt in self._get_expiration_dates() if dt.check_expired(cutoff_date)])
def validate(self, config): def validate(self, config):
""" Validates the license and all its entitlements against the given config. """ """ Returns a list of EntitlementValidationResult objects, one per requirement.
# Check that the license has not expired. """
if self.is_expired: requirements = _gen_entitlement_requirements(config)
raise LicenseValidationError('License has expired') now = datetime.now()
return [self.validate_entitlement_requirement(req, now) for req in requirements]
# Check the maximum number of replication regions.
max_regions = min(self.decoded.get('entitlements', {}).get('software.quay.regions', 1), 1)
config_regions = len(config.get('DISTRIBUTED_STORAGE_CONFIG', []))
if max_regions != -1 and config_regions > max_regions:
msg = '{} regions configured, but license file allows up to {}'.format(config_regions,
max_regions)
raise LicenseValidationError(msg)
def _get_expiration_dates(self):
# Check if the license overall has expired.
expiration_date = _get_date(self.decoded, 'expirationDate')
if expiration_date is None:
yield LicenseExpirationDate('No valid Tectonic Account License', datetime.min)
return
yield LicenseExpirationDate('Tectonic Account License', expiration_date)
# Check for any QE subscriptions.
sub = self.subscription
if sub is None:
yield LicenseExpirationDate('No Quay Enterprise Subscription', datetime.min)
return
# Check for a trial-only license.
if sub.get('trialOnly', False):
trial_end_date = _get_date(sub, 'trialEnd')
if trial_end_date is None:
yield LicenseExpirationDate('Invalid trial subscription', datetime.min)
else:
yield LicenseExpirationDate('Trial subscription', trial_end_date, TRIAL_GRACE_PERIOD)
return
# Check for a normal license that is in trial.
service_end_date = _get_date(sub, 'serviceEnd')
if service_end_date is None:
yield LicenseExpirationDate('No valid Quay Enterprise Subscription', datetime.min)
return
if sub.get('inTrial', False):
# If the subscription is in a trial, but not a trial only
# subscription, give 7 days after trial end to update license
# to one which has been paid (they've put in a credit card and it
# might auto convert, so we could assume it will auto-renew)
yield LicenseExpirationDate('In-trial subscription', service_end_date, TRIAL_GRACE_PERIOD)
# Otherwise, check the service expiration.
duration_period = sub.get('durationPeriod', 'months')
# If the subscription is monthly, give 3 months grace period
if duration_period == "months":
yield LicenseExpirationDate('Monthly subscription', service_end_date, MONTHLY_GRACE_PERIOD)
if duration_period == "years":
yield LicenseExpirationDate('Yearly subscription', service_end_date, YEARLY_GRACE_PERIOD)
_PROD_LICENSE_PUBLIC_KEY_DATA = """ _PROD_LICENSE_PUBLIC_KEY_DATA = """
@ -183,6 +265,17 @@ LICENSE_VALIDATION_INTERVAL = 3600 # seconds
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60 # seconds LICENSE_VALIDATION_EXPIRED_INTERVAL = 60 # seconds
EntitlementRequirement = namedtuple('EntitlementRequirements', ['name', 'count'])
def _gen_entitlement_requirements(config_obj):
config_regions = len(config_obj.get('DISTRIBUTED_STORAGE_CONFIG', []))
return [
EntitlementRequirement('software.quay', 1),
EntitlementRequirement('software.quay.regions', config_regions),
]
class LicenseValidator(Thread): class LicenseValidator(Thread):
""" """
LicenseValidator is a thread that asynchronously reloads and validates license files. LicenseValidator is a thread that asynchronously reloads and validates license files.
@ -192,35 +285,41 @@ class LicenseValidator(Thread):
""" """
def __init__(self, config_provider, *args, **kwargs): def __init__(self, config_provider, *args, **kwargs):
self._config_provider = config_provider self._config_provider = config_provider
self._entitlement_requirements = _gen_entitlement_requirements(config_provider.get_config())
# multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that. # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
self._license_is_expired = multiprocessing.Value(c_bool, True) self._license_is_insufficient = multiprocessing.Value(c_bool, True)
super(LicenseValidator, self).__init__(*args, **kwargs) super(LicenseValidator, self).__init__(*args, **kwargs)
self.daemon = True self.daemon = True
@property @property
def expired(self): def insufficient(self):
return self._license_is_expired.value return self._license_is_insufficient.value
def _check_expiration(self): def compute_license_sufficiency(self):
""" Check whether all of our requirements are met, and set the status of
the result of the check, which will be used to disable the software.
Returns True if any requirements are not met, and False if all are met.
"""
try: try:
current_license = self._config_provider.get_license() current_license = self._config_provider.get_license()
is_expired = current_license.is_expired now = datetime.now()
logger.debug('updating license expiration to %s', is_expired) any_invalid = not all(current_license.validate_entitlement_requirement(req, now).is_met()
self._license_is_expired.value = is_expired for req in self._entitlement_requirements)
except (IOError, LicenseError): logger.debug('updating license license_is_insufficient to %s', any_invalid)
except (IOError, LicenseDecodeError):
logger.exception('failed to validate license') logger.exception('failed to validate license')
is_expired = True any_invalid = True
self._license_is_expired.value = is_expired
return is_expired self._license_is_insufficient.value = any_invalid
return any_invalid
def run(self): def run(self):
logger.debug('Starting license validation thread') logger.debug('Starting license validation thread')
while True: while True:
expired = self._check_expiration() invalid = self.compute_license_sufficiency()
sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if expired else LICENSE_VALIDATION_INTERVAL sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if invalid else LICENSE_VALIDATION_INTERVAL
logger.debug('waiting %d seconds before retrying to validate license', sleep_time) logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
time.sleep(sleep_time) time.sleep(sleep_time)
@ -231,11 +330,11 @@ class LicenseValidator(Thread):
""" """
if response_func is None: if response_func is None:
def _response_func(): def _response_func():
return make_response('License has expired.', 402) return make_response('License is insufficient.', 402)
response_func = _response_func response_func = _response_func
def _enforce_license(): def _enforce_license():
if self.expired: if self.insufficient:
logger.debug('blocked interaction due to expired license') logger.debug('blocked interaction due to insufficient license')
return response_func() return response_func()
blueprint.before_request(_enforce_license) blueprint.before_request(_enforce_license)