util.license: make bp-modification a method
This commit is contained in:
parent
6eb26d7998
commit
a42eb09a3e
4 changed files with 28 additions and 23 deletions
|
@ -10,7 +10,7 @@ from threading import Thread
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
||||
from dateutil import parser
|
||||
from flask import abort
|
||||
from flask import make_response
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -167,13 +167,14 @@ class LicenseValidator(Thread):
|
|||
This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
|
||||
synchronization primitive.
|
||||
"""
|
||||
def __init__(self, license_path):
|
||||
def __init__(self, license_path, *args, **kwargs):
|
||||
self._license_path = license_path
|
||||
|
||||
# multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
|
||||
self._license_is_expired = multiprocessing.Value(c_bool, True)
|
||||
|
||||
super(LicenseValidator, self).__init__()
|
||||
super(LicenseValidator, self).__init__(*args, **kwargs)
|
||||
self.daemon = True
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
|
@ -183,13 +184,15 @@ class LicenseValidator(Thread):
|
|||
try:
|
||||
with open(self._license_path) as f:
|
||||
current_license = decode_license(f.read())
|
||||
logger.debug('updating license expiration to %s', current_license.is_expired)
|
||||
self._license_is_expired.value = current_license.is_expired
|
||||
return current_license.is_expired
|
||||
is_expired = current_license.is_expired
|
||||
logger.debug('updating license expiration to %s', is_expired)
|
||||
self._license_is_expired.value = is_expired
|
||||
except (IOError, LicenseError):
|
||||
logger.exception('failed to validate license')
|
||||
self._license_is_expired.value = True
|
||||
return True
|
||||
is_expired = True
|
||||
self._license_is_expired.value = is_expired
|
||||
|
||||
return is_expired
|
||||
|
||||
def run(self):
|
||||
logger.debug('Starting license validation thread')
|
||||
|
@ -199,13 +202,18 @@ class LicenseValidator(Thread):
|
|||
logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def enforce_license_before_request(self, blueprint, response_func=None):
|
||||
"""
|
||||
Adds a pre-check to a Flask blueprint such that if the provided license_validator determines the
|
||||
license has become invalid, the client will receive a HTTP 402 response.
|
||||
"""
|
||||
if response_func is None:
|
||||
def _response_func():
|
||||
return make_response('License has expired.', 402)
|
||||
response_func = _response_func
|
||||
|
||||
def enforce_license_before_request(license_validator, blueprint):
|
||||
"""
|
||||
Adds a pre-check to a Flask blueprint such that if the provided license_validator determines the
|
||||
license has become invalid, the client will receive a HTTP 402 response.
|
||||
"""
|
||||
def _enforce_license():
|
||||
if license_validator.expired:
|
||||
abort(402)
|
||||
blueprint.before_request(_enforce_license)
|
||||
def _enforce_license():
|
||||
if self.expired:
|
||||
logger.debug('blocked interaction due to expired license')
|
||||
return response_func()
|
||||
blueprint.before_request(_enforce_license)
|
||||
|
|
Reference in a new issue