diff --git a/util/config/oauth.py b/util/config/oauth.py index 6e3f6f078..44bf084f2 100644 --- a/util/config/oauth.py +++ b/util/config/oauth.py @@ -5,6 +5,10 @@ import time from cachetools import TTLCache from cachetools.func import lru_cache + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_der_public_key + from jwkest.jwk import KEYS from util import slash_join @@ -341,7 +345,10 @@ class OIDCConfig(OAuthConfig): rsa_key = list(keys)[0] rsa_key.deserialize() - return rsa_key.key.exportKey('PEM') + + # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing + # issues. + return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend()) class DexOAuthConfig(OIDCConfig):