This repository has been archived on 2020-03-24. You can view files and clone it, but cannot push or open issues or pull requests.
quay/util/license.py
Joseph Schorr 67f828279d Switch the license validator to use config_provider and have a test license
Fixes the currently broken tests, which try (and fail) to read the license file
2016-10-18 11:44:13 -04:00

241 lines
8.6 KiB
Python

import json
import logging
import multiprocessing
import time
from ctypes import c_bool
from datetime import datetime, timedelta
from threading import Thread
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from dateutil import parser
from flask import make_response
import jwt
logger = logging.getLogger(__name__)

# Grace periods added on top of a subscription's trial/service end date before it is treated as
# expired (see LicenseExpirationDate).
TRIAL_GRACE_PERIOD = timedelta(days=7)     # 1 week
MONTHLY_GRACE_PERIOD = timedelta(days=30)  # 1 month
YEARLY_GRACE_PERIOD = timedelta(days=90)   # 3 months

# Subscriptions whose 'productName' equals this value are the Quay Enterprise subscription.
LICENSE_PRODUCT_NAME = "quay-enterprise"
LICENSE_FILENAME = 'license'
class LicenseError(Exception):
    """ Base error: the license could not be read or decoded, or it failed validation/expired. """


class LicenseDecodeError(LicenseError):
    """ The license contents could not be decoded. """


class LicenseValidationError(LicenseError):
    """ The decoded license did not pass validation. """
def _get_date(decoded, field):
""" Retrieves the encoded date found at the given field under the decoded license block. """
date_str = decoded.get(field)
return parser.parse(date_str).replace(tzinfo=None) if date_str else None
class LicenseExpirationDate(object):
    """ One named expiration point of a license, with an optional grace period after it. """

    def __init__(self, title, expiration_date, grace_period=None):
        self.title = title
        self.expiration_date = expiration_date
        # No (or a zero) grace period means the entry lapses exactly at expiration_date.
        if not grace_period:
            grace_period = timedelta(seconds=0)
        self.grace_period = grace_period

    @property
    def expiration_and_grace(self):
        """ The expiration date pushed out by the grace period. """
        return self.expiration_date + self.grace_period

    def check_expired(self, cutoff_date=None):
        """ Returns True once cutoff_date (default: now) has reached expiration plus grace. """
        if not cutoff_date:
            cutoff_date = datetime.now()
        return cutoff_date >= self.expiration_and_grace

    def __str__(self):
        template = 'License expiration "%s" date %s with grace %s: %s'
        details = (self.title, self.expiration_date, self.grace_period, self.check_expired())
        return template % details
class License(object):
    """ License represents a fully decoded and validated (but potentially expired) license. """

    def __init__(self, decoded):
        # The decoded license document; read below as a dict with optional 'expirationDate',
        # 'subscriptions' and 'entitlements' keys.
        self.decoded = decoded

    @property
    def subscription(self):
        """ Returns the Quay Enterprise subscription, if any. """
        for sub in self.decoded.get('subscriptions', {}).values():
            if sub.get('productName') == LICENSE_PRODUCT_NAME:
                return sub
        return None

    @property
    def is_expired(self):
        """ Returns True if any of the license's expiration dates (including grace) has passed. """
        cutoff_date = datetime.now()
        return any(dt.check_expired(cutoff_date) for dt in self._get_expiration_dates())

    def validate(self, config):
        """ Validates the license and all its entitlements against the given config.

        Raises LicenseValidationError if the license has expired, or if the config defines more
        storage regions (DISTRIBUTED_STORAGE_CONFIG entries) than the license allows.
        """
        # Check that the license has not expired.
        if self.is_expired:
            raise LicenseValidationError('License has expired')

        # Check the maximum number of replication regions. An entitlement of -1 means unlimited;
        # any other value is floored at one region (the default when the entitlement is absent).
        # Note: max(), not min() -- min(x, 1) would cap every license at a single region and make
        # the entitlement meaningless.
        entitled_regions = self.decoded.get('entitlements', {}).get('software.quay.regions', 1)
        if entitled_regions == -1:
            return

        max_regions = max(entitled_regions, 1)
        # `or []` also tolerates the key being present with a None value.
        config_regions = len(config.get('DISTRIBUTED_STORAGE_CONFIG') or [])
        if config_regions > max_regions:
            msg = '{} regions configured, but license file allows up to {}'.format(config_regions,
                                                                                  max_regions)
            raise LicenseValidationError(msg)

    def _get_expiration_dates(self):
        """ Yields a LicenseExpirationDate for each expiration that applies to this license.

        Missing required data (account expiration date, Quay Enterprise subscription, trial end or
        service end) yields an entry dated datetime.min, which check_expired always reports as
        expired.
        """
        # The Tectonic account license itself must carry an overall expiration date.
        expiration_date = _get_date(self.decoded, 'expirationDate')
        if expiration_date is None:
            yield LicenseExpirationDate('No valid Tectonic Account License', datetime.min)
            return

        yield LicenseExpirationDate('Tectonic Account License', expiration_date)

        # There must be a Quay Enterprise subscription.
        sub = self.subscription
        if sub is None:
            yield LicenseExpirationDate('No Quay Enterprise Subscription', datetime.min)
            return

        # Trial-only subscriptions expire at the trial end plus the trial grace period.
        if sub.get('trialOnly', False):
            trial_end_date = _get_date(sub, 'trialEnd')
            if trial_end_date is None:
                yield LicenseExpirationDate('Invalid trial subscription', datetime.min)
            else:
                yield LicenseExpirationDate('Trial subscription', trial_end_date, TRIAL_GRACE_PERIOD)
            return

        # Every other subscription must have a service end date.
        service_end_date = _get_date(sub, 'serviceEnd')
        if service_end_date is None:
            yield LicenseExpirationDate('No valid Quay Enterprise Subscription', datetime.min)
            return

        if sub.get('inTrial', False):
            # If the subscription is in a trial, but not a trial only
            # subscription, give 7 days after trial end to update license
            # to one which has been paid (they've put in a credit card and it
            # might auto convert, so we could assume it will auto-renew)
            yield LicenseExpirationDate('In-trial subscription', service_end_date, TRIAL_GRACE_PERIOD)

        # In addition (this is not an else branch), apply the grace period of the billing period.
        duration_period = sub.get('durationPeriod', 'months')
        if duration_period == "months":
            # Monthly subscriptions get a 30 day (MONTHLY_GRACE_PERIOD) grace period.
            yield LicenseExpirationDate('Monthly subscription', service_end_date, MONTHLY_GRACE_PERIOD)
        elif duration_period == "years":
            # Yearly subscriptions get a 90 day (YEARLY_GRACE_PERIOD) grace period.
            yield LicenseExpirationDate('Yearly subscription', service_end_date, YEARLY_GRACE_PERIOD)
        # NOTE(review): any other durationPeriod value yields no service-end entry at all, so only
        # the account expiration above applies -- confirm this is intended.
# PEM-encoded public key that decode_license uses to verify license signatures when the caller
# does not supply its own key instance. The literal must stay byte-for-byte intact.
_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""
# Parsed once, at module import time, into a key object reused for every decode.
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA,
backend=default_backend())
def decode_license(license_contents, public_key_instance=None):
    """ Decodes and verifies the specified license contents, returning the decoded License.

    license_contents is the signed license JWT. public_key_instance optionally overrides the
    production public key used to verify the signature (e.g. for a test license).

    Raises LicenseDecodeError if the JWT cannot be verified/decoded, or if its 'license' claim
    is not a JSON-encoded object.
    """
    license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY

    try:
        # Pin the accepted algorithms to RSA signatures so the token cannot pick its own
        # verification scheme. InvalidTokenError is PyJWT's base class for DecodeError,
        # ExpiredSignatureError, InvalidAlgorithmError, etc.; catching only DecodeError let
        # those escape as non-LicenseError exceptions.
        jwt_data = jwt.decode(license_contents, key=license_public_key,
                              algorithms=['RS256', 'RS384', 'RS512'])
    except jwt.exceptions.InvalidTokenError as ite:
        logger.exception('Could not decode license file')
        # '%s' % exc instead of the Python-2-only `.message` attribute.
        raise LicenseDecodeError('Could not decode license found: %s' % ite)

    try:
        decoded = json.loads(jwt_data.get('license', '{}'))
    except (TypeError, ValueError) as ve:
        # TypeError: the 'license' claim is not a string at all.
        logger.exception('Could not decode license file')
        raise LicenseDecodeError('Could not decode license found: %s' % ve)

    # License reads the decoded data as a dict; reject lists/strings/numbers up front rather than
    # failing later with an AttributeError.
    if not isinstance(decoded, dict):
        raise LicenseDecodeError('Could not decode license found: license is not a JSON object')

    return License(decoded)
# Seconds LicenseValidator waits between checks while the license is valid (one hour).
LICENSE_VALIDATION_INTERVAL = 60 * 60
# Shorter wait, in seconds, used while the license is expired or invalid, so that a corrected
# license is noticed sooner.
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60
class LicenseValidator(Thread):
    """
    LicenseValidator is a thread that asynchronously reloads and validates license files.

    This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
    synchronization primitive.
    """

    def __init__(self, config_provider, *args, **kwargs):
        self._config_provider = config_provider

        # Starts as expired (True) until the first check completes.
        # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
        self._license_is_expired = multiprocessing.Value(c_bool, True)

        super(LicenseValidator, self).__init__(*args, **kwargs)
        self.daemon = True

    @property
    def expired(self):
        """ Returns the most recently recorded state: True if the license is expired or invalid. """
        return self._license_is_expired.value

    def _check_expiration(self):
        """ Reloads the license, records whether it is expired, and returns that state.

        Any failure to load or evaluate the license is recorded as expired.
        """
        try:
            current_license = self._config_provider.get_license()
            is_expired = current_license.is_expired
            logger.debug('updating license expiration to %s', is_expired)
        except (IOError, LicenseError):
            logger.exception('failed to validate license')
            is_expired = True
        except Exception:
            # This is the background thread's top-level boundary: any other error (e.g. while
            # parsing a date inside the license) would otherwise kill run(), leaving the last
            # recorded value -- possibly "not expired" -- in place forever. Fail closed instead.
            logger.exception('unexpected error while validating license')
            is_expired = True

        self._license_is_expired.value = is_expired
        return is_expired

    def run(self):
        """ Re-checks the license forever, polling more often while it is expired. """
        logger.debug('Starting license validation thread')
        while True:
            expired = self._check_expiration()
            sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if expired else LICENSE_VALIDATION_INTERVAL
            logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
            time.sleep(sleep_time)

    def enforce_license_before_request(self, blueprint, response_func=None):
        """
        Adds a pre-check to a Flask blueprint such that if this validator determines the license
        has become invalid, the client will receive a HTTP 402 response.

        response_func, if given, is called with no arguments to build the response returned in
        place of the default 402 'License has expired.' response.
        """
        if response_func is None:
            def _response_func():
                return make_response('License has expired.', 402)
            response_func = _response_func

        def _enforce_license():
            if self.expired:
                logger.debug('blocked interaction due to expired license')
                return response_func()

        blueprint.before_request(_enforce_license)