211 lines
7.4 KiB
Python
211 lines
7.4 KiB
Python
import json
|
|
import logging
|
|
import multiprocessing
|
|
import time
|
|
|
|
from ctypes import c_bool
|
|
from datetime import datetime, timedelta
|
|
from threading import Thread
|
|
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
|
from dateutil import parser
|
|
from flask import abort
|
|
|
|
import jwt
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Grace periods added after a subscription's end date before the license counts as expired.
TRIAL_GRACE_PERIOD = timedelta(days=7)  # one week
MONTHLY_GRACE_PERIOD = timedelta(days=30)  # roughly one month
YEARLY_GRACE_PERIOD = timedelta(days=90)  # roughly three months

# productName that identifies the Quay Enterprise entry among a license's subscriptions.
LICENSE_PRODUCT_NAME = "quay-enterprise"

# File name for the license (not referenced elsewhere in this module's visible code).
LICENSE_FILENAME = 'license'
|
|
|
|
|
|
class LicenseError(Exception):
    """ Base error for a license that cannot be read or decoded, or that has expired. """
|
|
|
|
|
|
class LicenseDecodeError(LicenseError):
    """ Raised when the contents of a license cannot be decoded. """
|
|
|
|
|
|
class LicenseValidationError(LicenseError):
    """ Raised when a decoded license fails validation. """
|
|
|
|
|
|
def _get_date(decoded, field):
|
|
""" Retrieves the encoded date found at the given field under the decoded license block. """
|
|
date_str = decoded.get(field)
|
|
if date_str:
|
|
return parser.parse(date_str).replace(tzinfo=None)
|
|
|
|
return datetime.now() - timedelta(days=2)
|
|
|
|
|
|
class License(object):
    """ License represents a fully decoded and validated (but potentially expired) license. """

    def __init__(self, decoded):
        # The decoded license payload (the JSON object decode_license() extracts from the token).
        self.decoded = decoded

    @property
    def subscription(self):
        """ Returns the Quay Enterprise subscription, if any.

        Searches the ``subscriptions`` mapping of the decoded license for an entry whose
        ``productName`` equals LICENSE_PRODUCT_NAME, and returns None when there is no match.
        """
        subscriptions = self.decoded.get('subscriptions') or {}
        for sub in subscriptions.values():
            if sub.get('productName') == LICENSE_PRODUCT_NAME:
                return sub

        return None

    @property
    def is_expired(self):
        """ Returns True if the license is expired as of the current (naive, local) time. """
        return self._get_expired(datetime.now())

    def validate(self, config):
        """ Validates the license and all its entitlements against the given config.

        Raises:
          LicenseValidationError: if the license has expired, or if the config declares more
            DISTRIBUTED_STORAGE_CONFIG regions than the license allows.
        """
        # Check that the license has not expired.
        if self.is_expired:
            raise LicenseValidationError('License has expired')

        # Check the maximum number of replication regions.
        entitlements = self.decoded.get('entitlements') or {}
        max_regions = self._max_regions(entitlements)
        config_regions = len(config.get('DISTRIBUTED_STORAGE_CONFIG') or [])
        if max_regions != -1 and config_regions > max_regions:
            msg = '{} regions configured, but license file allows up to {}'.format(config_regions,
                                                                                  max_regions)
            raise LicenseValidationError(msg)

    @staticmethod
    def _max_regions(entitlements):
        """ Returns how many storage regions the entitlements allow, or -1 for unlimited.

        The ``software.quay.regions`` entitlement defaults to 1 when absent. Any value other than
        -1 is raised to at least 1, so a single-region configuration is always permitted, while
        larger entitlements are honored as-is (previously they were clamped down to 1).
        """
        regions = entitlements.get('software.quay.regions', 1)
        if regions == -1:
            return -1

        return max(regions, 1)

    def _get_expired(self, compare_date):
        """ Returns True if the license should be treated as expired at ``compare_date``.

        The license is expired when its overall ``expirationDate`` has passed, when it carries no
        Quay Enterprise subscription, or when that subscription's end date lies further in the
        past than the applicable grace period (trial, monthly or yearly). Subscriptions with an
        unrecognized ``durationPeriod`` are treated as expired.
        """
        # Check if the license overall has expired (no grace period applies here).
        expiration_date = _get_date(self.decoded, 'expirationDate')
        if expiration_date <= compare_date:
            logger.debug('License expired on %s', expiration_date)
            return True

        # Check for any QE subscriptions.
        sub = self.subscription
        if sub is None:
            return True

        # Check for a trial-only license.
        if sub.get('trialOnly', False):
            trial_end_date = _get_date(sub, 'trialEnd')
            logger.debug('Trial-only license expires on %s', trial_end_date)
            return trial_end_date <= (compare_date - TRIAL_GRACE_PERIOD)

        # Check for a normal license that is in trial.
        service_end_date = _get_date(sub, 'serviceEnd')
        if sub.get('inTrial', False):
            # If the subscription is in a trial, but not a trial only
            # subscription, give TRIAL_GRACE_PERIOD after trial end to update license
            # to one which has been paid (they've put in a credit card and it
            # might auto convert, so we could assume it will auto-renew)
            logger.debug('In-trial license expires on %s', service_end_date)
            return service_end_date <= (compare_date - TRIAL_GRACE_PERIOD)

        # Otherwise, check the service expiration.
        duration_period = sub.get('durationPeriod', 'months')

        # If the subscription is monthly, give MONTHLY_GRACE_PERIOD (about 1 month) of grace.
        if duration_period == "months":
            logger.debug('Monthly license expires on %s', service_end_date)
            return service_end_date <= (compare_date - MONTHLY_GRACE_PERIOD)

        # If the subscription is yearly, give YEARLY_GRACE_PERIOD (about 3 months) of grace.
        if duration_period == "years":
            logger.debug('Yearly license expires on %s', service_end_date)
            return service_end_date <= (compare_date - YEARLY_GRACE_PERIOD)

        # Unknown billing period: fail closed.
        return True
|
|
|
|
|
|
# PEM-encoded public key that decode_license() uses by default to verify license signatures.
_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""

# load_pem_public_key requires bytes. The PEM text is pure ASCII, so encoding it works on
# Python 2 (where it yields the same str) and on Python 3 (where a text str would raise TypeError).
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA.encode('ascii'),
                                               backend=default_backend())
|
|
|
|
def decode_license(license_contents, public_key_instance=None):
    """ Decodes the specified license contents, returning the decoded license.

    Args:
      license_contents: the signed JWT text of a license file.
      public_key_instance: key used to verify the token; defaults to the production license
        public key.

    Returns:
      A License wrapping the JSON object stored in the token's ``license`` claim (an empty
      object when the claim is absent).

    Raises:
      LicenseDecodeError: if the token cannot be decoded or verified, or if its ``license``
        claim is not a JSON-encoded object.
    """
    license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY

    try:
        # NOTE(review): no ``algorithms`` allow-list is passed, so the algorithm is taken from the
        # token header; consider pinning the algorithm(s) the license issuer actually signs with.
        jwt_data = jwt.decode(license_contents, key=license_public_key)
    except jwt.exceptions.InvalidTokenError as ite:
        # InvalidTokenError is the base of DecodeError and of claim-validation failures such as
        # ExpiredSignatureError, so every token problem surfaces as a LicenseDecodeError.
        logger.exception('Could not decode license file')
        raise LicenseDecodeError('Could not decode license found: %s' % (ite,))

    try:
        decoded = json.loads(jwt_data.get('license', '{}'))
    except (TypeError, ValueError) as ve:
        # ValueError: malformed JSON; TypeError: the claim is not a string at all.
        logger.exception('Could not decode license file')
        raise LicenseDecodeError('Could not decode license found: %s' % (ve,))

    # License reads fields with dict.get(), so anything other than a JSON object is unusable.
    if not isinstance(decoded, dict):
        raise LicenseDecodeError('Could not decode license found: license is not a JSON object')

    return License(decoded)
|
|
|
|
|
|
# Delay between license re-checks in LicenseValidator.run() while the license is valid.
LICENSE_VALIDATION_INTERVAL = 60 * 60  # seconds (one hour)

# Shorter re-check delay used while the license is considered expired.
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60  # seconds
|
|
|
|
|
|
class LicenseValidator(Thread):
    """
    LicenseValidator is a thread that asynchronously reloads and validates license files.

    This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
    synchronization primitive.

    The file at ``license_path`` is re-read every LICENSE_VALIDATION_INTERVAL seconds, or every
    LICENSE_VALIDATION_EXPIRED_INTERVAL seconds while the license is considered expired. The
    result is published through a ``multiprocessing.Value`` and exposed via ``expired``.
    """

    def __init__(self, license_path):
        self._license_path = license_path

        # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
        # Start pessimistic: the license counts as expired until the first successful check.
        self._license_is_expired = multiprocessing.Value(c_bool, True)

        super(LicenseValidator, self).__init__()

        # run() never returns; as a daemon thread it does not block interpreter shutdown.
        self.daemon = True

    @property
    def expired(self):
        """ Returns the result of the most recent check (True until the first check succeeds). """
        return self._license_is_expired.value

    def _check_expiration(self):
        """ Re-reads the license file, publishes whether it is expired and returns that value.

        Any failure to read, decode or evaluate the license is treated as expired (fail closed).
        """
        try:
            with open(self._license_path) as f:
                contents = f.read()
            current_license = decode_license(contents)
            # Evaluate once: is_expired consults the clock on every access, so repeated reads
            # could disagree across an expiry boundary.
            is_expired = current_license.is_expired
        except Exception:
            # Besides IOError/LicenseError, evaluating the license can raise (e.g. a malformed date
            # passed to the date parser). Letting that escape would kill this thread and freeze
            # the last published value, so record the license as expired instead.
            logger.exception('failed to validate license')
            is_expired = True
        else:
            logger.debug('updating license expiration to %s', is_expired)

        self._license_is_expired.value = is_expired
        return is_expired

    def run(self):
        """ Re-validates the license forever, polling more often while it is expired. """
        logger.debug('Starting license validation thread')
        while True:
            expired = self._check_expiration()
            sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if expired else LICENSE_VALIDATION_INTERVAL
            logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
            time.sleep(sleep_time)
|
|
|
|
|
|
def enforce_license_before_request(license_validator, blueprint):
    """
    Registers a before-request hook on ``blueprint`` that calls ``abort(402)`` (HTTP 402) whenever
    ``license_validator.expired`` is truthy, so clients of that blueprint are refused while the
    license is invalid.
    """
    def _reject_if_expired():
        # A valid license lets the request through untouched.
        if not license_validator.expired:
            return
        abort(402)

    blueprint.before_request(_reject_if_expired)
|