diff --git a/endpoints/v1/__init__.py b/endpoints/v1/__init__.py index 1407a1b25..18ef430c4 100644 --- a/endpoints/v1/__init__.py +++ b/endpoints/v1/__init__.py @@ -2,12 +2,11 @@ from flask import Blueprint, make_response from app import metric_queue, license_validator from endpoints.decorators import anon_protect, anon_allowed -from util.license import enforce_license_before_request from util.metrics.metricqueue import time_blueprint v1_bp = Blueprint('v1', __name__) -enforce_license_before_request(license_validator, v1_bp) +license_validator.enforce_license_before_request(v1_bp) time_blueprint(v1_bp, metric_queue) diff --git a/endpoints/v2/__init__.py b/endpoints/v2/__init__.py index dda9baca4..aab985353 100644 --- a/endpoints/v2/__init__.py +++ b/endpoints/v2/__init__.py @@ -18,7 +18,6 @@ from data import model from endpoints.decorators import anon_protect, anon_allowed from endpoints.v2.errors import V2RegistryException, Unauthorized from util.http import abort -from util.license import enforce_license_before_request from util.metrics.metricqueue import time_blueprint from util.registry.dockerver import docker_version from util.pagination import encrypt_page_token, decrypt_page_token @@ -28,7 +27,7 @@ logger = logging.getLogger(__name__) v2_bp = Blueprint('v2', __name__) -enforce_license_before_request(license_validator, v2_bp) +license_validator.enforce_license_before_request(v2_bp) time_blueprint(v2_bp, metric_queue) diff --git a/endpoints/verbs/__init__.py b/endpoints/verbs/__init__.py index 87e4fa644..d7c8e248b 100644 --- a/endpoints/verbs/__init__.py +++ b/endpoints/verbs/__init__.py @@ -18,7 +18,6 @@ from endpoints.v2.blob import BLOB_DIGEST_ROUTE from image.appc import AppCImageFormatter from image.docker.squashed import SquashedDockerImageFormatter from storage import Storage -from util.license import enforce_license_before_request from util.registry.filelike import wrap_with_handler from util.registry.queuefile import QueueFile from util.registry.queueprocess import QueueProcess @@ -29,7 +28,7 @@ from util.registry.torrent import (make_torrent, per_user_torrent_filename, publ logger = logging.getLogger(__name__) verbs = Blueprint('verbs', __name__) -enforce_license_before_request(license_validator, verbs) +license_validator.enforce_license_before_request(verbs) def _open_stream(formatter, namespace, repository, tag, derived_image_id, repo_image, handlers): diff --git a/util/license.py b/util/license.py index 62a0b97b6..85a77ff3c 100644 --- a/util/license.py +++ b/util/license.py @@ -10,7 +10,7 @@ from threading import Thread from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_pem_public_key from dateutil import parser -from flask import abort +from flask import make_response import jwt @@ -167,13 +167,14 @@ class LicenseValidator(Thread): This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a synchronization primitive. """ - def __init__(self, license_path): + def __init__(self, license_path, *args, **kwargs): self._license_path = license_path # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that. self._license_is_expired = multiprocessing.Value(c_bool, True) - super(LicenseValidator, self).__init__() + super(LicenseValidator, self).__init__(*args, **kwargs) + self.daemon = True @property def expired(self): @@ -183,13 +184,15 @@ class LicenseValidator(Thread): try: with open(self._license_path) as f: current_license = decode_license(f.read()) - logger.debug('updating license expiration to %s', current_license.is_expired) - self._license_is_expired.value = current_license.is_expired - return current_license.is_expired + is_expired = current_license.is_expired + logger.debug('updating license expiration to %s', is_expired) + self._license_is_expired.value = is_expired except (IOError, LicenseError): logger.exception('failed to validate license') - self._license_is_expired.value = True - return True + is_expired = True + self._license_is_expired.value = is_expired + + return is_expired def run(self): logger.debug('Starting license validation thread') @@ -199,13 +202,18 @@ class LicenseValidator(Thread): logger.debug('waiting %d seconds before retrying to validate license', sleep_time) time.sleep(sleep_time) + def enforce_license_before_request(self, blueprint, response_func=None): + """ + Adds a pre-check to a Flask blueprint such that if the provided license_validator determines the + license has become invalid, the client will receive a HTTP 402 response. + """ + if response_func is None: + def _response_func(): + return make_response('License has expired.', 402) + response_func = _response_func -def enforce_license_before_request(license_validator, blueprint): - """ - Adds a pre-check to a Flask blueprint such that if the provided license_validator determines the - license has become invalid, the client will receive a HTTP 402 response. - """ - def _enforce_license(): - if license_validator.expired: - abort(402) - blueprint.before_request(_enforce_license) + def _enforce_license(): + if self.expired: + logger.debug('blocked interaction due to expired license') + return response_func() + blueprint.before_request(_enforce_license)