util.license: make bp-modification a method
This commit is contained in:
parent
6eb26d7998
commit
a42eb09a3e
4 changed files with 28 additions and 23 deletions
|
@ -2,12 +2,11 @@ from flask import Blueprint, make_response
|
|||
|
||||
from app import metric_queue, license_validator
|
||||
from endpoints.decorators import anon_protect, anon_allowed
|
||||
from util.license import enforce_license_before_request
|
||||
from util.metrics.metricqueue import time_blueprint
|
||||
|
||||
|
||||
v1_bp = Blueprint('v1', __name__)
|
||||
enforce_license_before_request(license_validator, v1_bp)
|
||||
license_validator.enforce_license_before_request(v1_bp)
|
||||
time_blueprint(v1_bp, metric_queue)
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ from data import model
|
|||
from endpoints.decorators import anon_protect, anon_allowed
|
||||
from endpoints.v2.errors import V2RegistryException, Unauthorized
|
||||
from util.http import abort
|
||||
from util.license import enforce_license_before_request
|
||||
from util.metrics.metricqueue import time_blueprint
|
||||
from util.registry.dockerver import docker_version
|
||||
from util.pagination import encrypt_page_token, decrypt_page_token
|
||||
|
@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
v2_bp = Blueprint('v2', __name__)
|
||||
enforce_license_before_request(license_validator, v2_bp)
|
||||
license_validator.enforce_license_before_request(v2_bp)
|
||||
time_blueprint(v2_bp, metric_queue)
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ from endpoints.v2.blob import BLOB_DIGEST_ROUTE
|
|||
from image.appc import AppCImageFormatter
|
||||
from image.docker.squashed import SquashedDockerImageFormatter
|
||||
from storage import Storage
|
||||
from util.license import enforce_license_before_request
|
||||
from util.registry.filelike import wrap_with_handler
|
||||
from util.registry.queuefile import QueueFile
|
||||
from util.registry.queueprocess import QueueProcess
|
||||
|
@ -29,7 +28,7 @@ from util.registry.torrent import (make_torrent, per_user_torrent_filename, publ
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
verbs = Blueprint('verbs', __name__)
|
||||
enforce_license_before_request(license_validator, verbs)
|
||||
license_validator.enforce_license_before_request(verbs)
|
||||
|
||||
|
||||
def _open_stream(formatter, namespace, repository, tag, derived_image_id, repo_image, handlers):
|
||||
|
|
|
@ -10,7 +10,7 @@ from threading import Thread
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
||||
from dateutil import parser
|
||||
from flask import abort
|
||||
from flask import make_response
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -167,13 +167,14 @@ class LicenseValidator(Thread):
|
|||
This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
|
||||
synchronization primitive.
|
||||
"""
|
||||
def __init__(self, license_path):
|
||||
def __init__(self, license_path, *args, **kwargs):
|
||||
self._license_path = license_path
|
||||
|
||||
# multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
|
||||
self._license_is_expired = multiprocessing.Value(c_bool, True)
|
||||
|
||||
super(LicenseValidator, self).__init__()
|
||||
super(LicenseValidator, self).__init__(*args, **kwargs)
|
||||
self.daemon = True
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
|
@ -183,13 +184,15 @@ class LicenseValidator(Thread):
|
|||
try:
|
||||
with open(self._license_path) as f:
|
||||
current_license = decode_license(f.read())
|
||||
logger.debug('updating license expiration to %s', current_license.is_expired)
|
||||
self._license_is_expired.value = current_license.is_expired
|
||||
return current_license.is_expired
|
||||
is_expired = current_license.is_expired
|
||||
logger.debug('updating license expiration to %s', is_expired)
|
||||
self._license_is_expired.value = is_expired
|
||||
except (IOError, LicenseError):
|
||||
logger.exception('failed to validate license')
|
||||
self._license_is_expired.value = True
|
||||
return True
|
||||
is_expired = True
|
||||
self._license_is_expired.value = is_expired
|
||||
|
||||
return is_expired
|
||||
|
||||
def run(self):
|
||||
logger.debug('Starting license validation thread')
|
||||
|
@ -199,13 +202,18 @@ class LicenseValidator(Thread):
|
|||
logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def enforce_license_before_request(license_validator, blueprint):
|
||||
def enforce_license_before_request(self, blueprint, response_func=None):
|
||||
"""
|
||||
Adds a pre-check to a Flask blueprint such that if the provided license_validator determines the
|
||||
license has become invalid, the client will receive a HTTP 402 response.
|
||||
"""
|
||||
if response_func is None:
|
||||
def _response_func():
|
||||
return make_response('License has expired.', 402)
|
||||
response_func = _response_func
|
||||
|
||||
def _enforce_license():
|
||||
if license_validator.expired:
|
||||
abort(402)
|
||||
if self.expired:
|
||||
logger.debug('blocked interaction due to expired license')
|
||||
return response_func()
|
||||
blueprint.before_request(_enforce_license)
|
||||
|
|
Reference in a new issue