Switch the license validator to use config_provider and have a test license

Fixes the broken tests currently which try (and fail) to read the license file
This commit is contained in:
Joseph Schorr 2016-10-18 11:44:13 -04:00
parent 2a7dbd3348
commit 67f828279d
4 changed files with 23 additions and 9 deletions

2
app.py
View file

@ -193,7 +193,7 @@ signer = Signer(app, config_provider)
instance_keys = InstanceKeys(app)
label_validator = LabelValidator(app)
license_validator = LicenseValidator(os.path.join(OVERRIDE_CONFIG_DIRECTORY, LICENSE_FILENAME))
license_validator = LicenseValidator(config_provider)
license_validator.start()
start_cloudwatch_sender(metric_queue, app)

View file

@ -4173,7 +4173,7 @@ class TestSuperUserLicense(ApiTestCase):
self._run_test('GET', 403, 'reader', None)
def test_get_devtable(self):
self._run_test('GET', 400, 'devtable', None)
self._run_test('GET', 200, 'devtable', None)
def test_put_anonymous(self):

View file

@ -5,6 +5,18 @@ from util.config.provider.baseprovider import BaseProvider
REAL_FILES = ['test/data/signing-private.gpg', 'test/data/signing-public.gpg']
class TestLicense(object):
@property
def subscription(self):
return {}
@property
def is_expired(self):
return False
def validate(self, config):
pass
class TestConfigProvider(BaseProvider):
""" Implementation of the config provider for testing. Everything is kept in-memory instead on
the real file system. """
@ -58,6 +70,9 @@ class TestConfigProvider(BaseProvider):
def requires_restart(self, app_config):
return False
def get_license(self):
return TestLicense()
def reset_for_test(self):
self._config['SUPER_USERS'] = ['devtable']
self.files = {}

View file

@ -190,8 +190,8 @@ class LicenseValidator(Thread):
This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
synchronization primitive.
"""
def __init__(self, license_path, *args, **kwargs):
self._license_path = license_path
def __init__(self, config_provider, *args, **kwargs):
self._config_provider = config_provider
# multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
self._license_is_expired = multiprocessing.Value(c_bool, True)
@ -205,8 +205,7 @@ class LicenseValidator(Thread):
def _check_expiration(self):
try:
with open(self._license_path) as f:
current_license = decode_license(f.read())
current_license = self._config_provider.get_license()
is_expired = current_license.is_expired
logger.debug('updating license expiration to %s', is_expired)
self._license_is_expired.value = is_expired