diff --git a/util/expiresdict.py b/util/expiresdict.py index 0eacf6a07..952cad26b 100644 --- a/util/expiresdict.py +++ b/util/expiresdict.py @@ -38,18 +38,20 @@ class ExpiresDict(object): return found.value # Otherwise the key has expired or was not found. Rebuild the cache and check it again. - self._rebuild() - found = self._items.get(key) - if found is None: + items = self._rebuild() + found_item = items.get(key) + if found_item is None: return default_value - return found.value + return found_item.value def __contains__(self, key): return self.get(key) is not None def _rebuild(self): - self._items = self._rebuilder() + items = self._rebuilder() + self._items = items + return items def set(self, key, value, expires=None): self._items[key] = ExpiresEntry(value, expires=expires) diff --git a/util/security/instancekeys.py b/util/security/instancekeys.py index 768c1e67b..75269552c 100644 --- a/util/security/instancekeys.py +++ b/util/security/instancekeys.py @@ -4,6 +4,23 @@ from util.expiresdict import ExpiresDict, ExpiresEntry from util.security import jwtutil +class CachingKey(object): + def __init__(self, service_key): + self._service_key = service_key + self._cached_public_key = None + + @property + def public_key(self): + cached_key = self._cached_public_key + if cached_key is not None: + return cached_key + + # Convert the JWK into a public key and cache it (since the conversion can take > 200ms). + public_key = jwtutil.jwk_dict_to_public_key(self._service_key.jwk) + self._cached_public_key = public_key + return public_key + + class InstanceKeys(object): """ InstanceKeys defines a helper class for interacting with the Quay instance service keys used for JWT signing of registry tokens as well as requests from Quay to other services @@ -12,23 +29,16 @@ class InstanceKeys(object): def __init__(self, app): self.app = app self.instance_keys = ExpiresDict(self._load_instance_keys) - self.public_keys = {} def clear_cache(self): """ Clears the cache of instance keys. """ self.instance_keys = ExpiresDict(self._load_instance_keys) - self.public_keys = {} def _load_instance_keys(self): # Load all the instance keys. keys = {} for key in model.service_keys.list_service_keys(self.service_name): - keys[key.kid] = ExpiresEntry(key, key.expiration_date) - - # Remove any expired or deleted keys from the public keys cache. - for key in dict(self.public_keys): - if key not in keys: - self.public_keys.pop(key) + keys[key.kid] = ExpiresEntry(CachingKey(key), key.expiration_date) return keys @@ -56,23 +66,11 @@ class InstanceKeys(object): def get_service_key_public_key(self, kid): """ Returns the public key associated with the given instance service key or None if none. """ - - # Note: We do the lookup via instance_keys *first* to ensure that if a key has expired, we - # don't use the entry in the public key cache. - service_key = self.instance_keys.get(kid) - if service_key is None: - # Remove the kid from the cache just to be sure. - self.public_keys.pop(kid, None) + caching_key = self.instance_keys.get(kid) + if caching_key is None: return None - public_key = self.public_keys.get(kid) - if public_key is not None: - return public_key - - # Convert the JWK into a public key and cache it (since the conversion can take > 200ms). - public_key = jwtutil.jwk_dict_to_public_key(service_key.jwk) - self.public_keys[kid] = public_key - return public_key + return caching_key.public_key def _load_file_contents(path):