Refactor and rename the standard OAuth services
This commit is contained in:
parent
bee2551dc2
commit
4755d08677
6 changed files with 82 additions and 242 deletions
21
app.py
21
app.py
|
@ -31,19 +31,15 @@ from util.saas.analytics import Analytics
|
||||||
from util.saas.useranalytics import UserAnalytics
|
from util.saas.useranalytics import UserAnalytics
|
||||||
from util.saas.exceptionlog import Sentry
|
from util.saas.exceptionlog import Sentry
|
||||||
from util.names import urn_generator
|
from util.names import urn_generator
|
||||||
|
from util.oauth.services import GoogleOAuthService, GithubOAuthService, GitLabOAuthService
|
||||||
from util.config.configutil import generate_secret_key
|
from util.config.configutil import generate_secret_key
|
||||||
from util.config.oauth import (GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthConfig,
|
|
||||||
DexOAuthConfig)
|
|
||||||
from util.config.provider import get_config_provider
|
from util.config.provider import get_config_provider
|
||||||
from util.config.superusermanager import SuperUserManager
|
from util.config.superusermanager import SuperUserManager
|
||||||
from util.label_validator import LabelValidator
|
from util.label_validator import LabelValidator
|
||||||
from util.license import LicenseValidator, LICENSE_FILENAME
|
from util.license import LicenseValidator
|
||||||
from util.metrics.metricqueue import MetricQueue
|
from util.metrics.metricqueue import MetricQueue
|
||||||
from util.metrics.prometheus import PrometheusPlugin
|
from util.metrics.prometheus import PrometheusPlugin
|
||||||
from util.names import urn_generator
|
|
||||||
from util.saas.analytics import Analytics
|
|
||||||
from util.saas.cloudwatch import start_cloudwatch_sender
|
from util.saas.cloudwatch import start_cloudwatch_sender
|
||||||
from util.saas.exceptionlog import Sentry
|
|
||||||
from util.secscan.api import SecurityScannerAPI
|
from util.secscan.api import SecurityScannerAPI
|
||||||
from util.security.instancekeys import InstanceKeys
|
from util.security.instancekeys import InstanceKeys
|
||||||
from util.security.signing import Signer
|
from util.security.signing import Signer
|
||||||
|
@ -204,13 +200,12 @@ license_validator.start()
|
||||||
|
|
||||||
start_cloudwatch_sender(metric_queue, app)
|
start_cloudwatch_sender(metric_queue, app)
|
||||||
|
|
||||||
github_login = GithubOAuthConfig(app.config, 'GITHUB_LOGIN_CONFIG')
|
github_login = GithubOAuthService(app.config, 'GITHUB_LOGIN_CONFIG')
|
||||||
github_trigger = GithubOAuthConfig(app.config, 'GITHUB_TRIGGER_CONFIG')
|
github_trigger = GithubOAuthService(app.config, 'GITHUB_TRIGGER_CONFIG')
|
||||||
gitlab_trigger = GitLabOAuthConfig(app.config, 'GITLAB_TRIGGER_CONFIG')
|
gitlab_trigger = GitLabOAuthService(app.config, 'GITLAB_TRIGGER_CONFIG')
|
||||||
google_login = GoogleOAuthConfig(app.config, 'GOOGLE_LOGIN_CONFIG')
|
google_login = GoogleOAuthService(app.config, 'GOOGLE_LOGIN_CONFIG')
|
||||||
dex_login = DexOAuthConfig(app.config, 'DEX_LOGIN_CONFIG')
|
|
||||||
|
|
||||||
oauth_apps = [github_login, github_trigger, gitlab_trigger, google_login, dex_login]
|
oauth_apps = [github_login, github_trigger, gitlab_trigger, google_login]
|
||||||
|
|
||||||
image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf, has_namespace=False)
|
image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf, has_namespace=False)
|
||||||
dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf,
|
dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf,
|
||||||
|
@ -240,7 +235,7 @@ model.config.register_image_cleanup_callback(secscan_api.cleanup_layers)
|
||||||
|
|
||||||
@login_manager.user_loader
|
@login_manager.user_loader
|
||||||
def load_user(user_uuid):
|
def load_user(user_uuid):
|
||||||
logger.debug('User loader loading deferred user with uuid: %s' % user_uuid)
|
logger.debug('User loader loading deferred user with uuid: %s', user_uuid)
|
||||||
return LoginWrappedDBUser(user_uuid)
|
return LoginWrappedDBUser(user_uuid)
|
||||||
|
|
||||||
class LoginWrappedDBUser(UserMixin):
|
class LoginWrappedDBUser(UserMixin):
|
||||||
|
|
|
@ -7,7 +7,7 @@ from peewee import IntegrityError
|
||||||
|
|
||||||
import features
|
import features
|
||||||
|
|
||||||
from app import app, analytics, get_app_url, github_login, google_login, dex_login
|
from app import app, analytics, get_app_url, github_login, google_login
|
||||||
from auth.process import require_session_login
|
from auth.process import require_session_login
|
||||||
from data import model
|
from data import model
|
||||||
from endpoints.common import common_login, route_show_if
|
from endpoints.common import common_login, route_show_if
|
||||||
|
@ -281,14 +281,3 @@ def github_oauth_attach():
|
||||||
return redirect(url_for('web.user_view', path=user_obj.username, tab='external'))
|
return redirect(url_for('web.user_view', path=user_obj.username, tab='external'))
|
||||||
|
|
||||||
|
|
||||||
def decode_user_jwt(token, oidc_provider):
|
|
||||||
try:
|
|
||||||
return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'],
|
|
||||||
audience=oidc_provider.client_id(),
|
|
||||||
issuer=oidc_provider.issuer)
|
|
||||||
except InvalidTokenError:
|
|
||||||
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
|
|
||||||
return decode(token, oidc_provider.get_public_key(force_refresh=True), algorithms=['RS256'],
|
|
||||||
audience=oidc_provider.client_id(),
|
|
||||||
issuer=oidc_provider.issuer)
|
|
||||||
|
|
||||||
|
|
|
@ -218,50 +218,6 @@ class OAuthLoginTestCase(EndpointTestCase):
|
||||||
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
|
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
|
||||||
'someid', 'someusername')
|
'someid', 'someusername')
|
||||||
|
|
||||||
def test_dex_oauth(self):
|
|
||||||
# TODO(jschorr): Add tests for invalid and expired keys.
|
|
||||||
|
|
||||||
# Generate a public/private key pair for the OIDC transaction.
|
|
||||||
private_key = RSA.generate(2048)
|
|
||||||
jwk = RSAKey(key=private_key.publickey()).serialize()
|
|
||||||
token = jwt.encode({
|
|
||||||
'iss': 'https://oidcserver/',
|
|
||||||
'aud': 'someclientid',
|
|
||||||
'sub': 'someid',
|
|
||||||
'exp': int(time.time()) + 60,
|
|
||||||
'iat': int(time.time()),
|
|
||||||
'nbf': int(time.time()),
|
|
||||||
'email': 'someemail@example.com',
|
|
||||||
'email_verified': True,
|
|
||||||
}, private_key.exportKey('PEM'), 'RS256')
|
|
||||||
|
|
||||||
@urlmatch(netloc=r'oidcserver', path='/.well-known/openid-configuration')
|
|
||||||
def wellknown_handler(url, _):
|
|
||||||
return py_json.dumps({
|
|
||||||
'authorization_endpoint': 'http://oidcserver/auth',
|
|
||||||
'token_endpoint': 'http://oidcserver/token',
|
|
||||||
'jwks_uri': 'http://oidcserver/keys',
|
|
||||||
})
|
|
||||||
|
|
||||||
@urlmatch(netloc=r'oidcserver', path='/token')
|
|
||||||
def account_handler(url, request):
|
|
||||||
if request.body.find("code=somecode") > 0:
|
|
||||||
return py_json.dumps({
|
|
||||||
'access_token': token,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
|
||||||
|
|
||||||
@urlmatch(netloc=r'oidcserver', path='/keys')
|
|
||||||
def keys_handler(_, __):
|
|
||||||
return py_json.dumps({
|
|
||||||
"keys": [jwk],
|
|
||||||
})
|
|
||||||
|
|
||||||
with HTTMock(wellknown_handler, account_handler, keys_handler):
|
|
||||||
self.invoke_oauth_tests('dex_oauth_callback', 'dex_oauth_attach', 'dex',
|
|
||||||
'someid', 'someemail')
|
|
||||||
|
|
||||||
|
|
||||||
class WebEndpointTestCase(EndpointTestCase):
|
class WebEndpointTestCase(EndpointTestCase):
|
||||||
def test_index(self):
|
def test_index(self):
|
||||||
|
|
0
util/oauth/__init__.py
Normal file
0
util/oauth/__init__.py
Normal file
63
util/oauth/base.py
Normal file
63
util/oauth/base.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
class OAuthService(object):
|
||||||
|
""" A base class for defining an external service, exposed via OAuth. """
|
||||||
|
def __init__(self, config, key_name):
|
||||||
|
self.key_name = key_name
|
||||||
|
self.config = config.get(key_name) or {}
|
||||||
|
|
||||||
|
def service_name(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def token_endpoint(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def user_endpoint(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def validate_client_id_and_secret(self, http_client, app_config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def client_id(self):
|
||||||
|
return self.config.get('CLIENT_ID')
|
||||||
|
|
||||||
|
def client_secret(self):
|
||||||
|
return self.config.get('CLIENT_SECRET')
|
||||||
|
|
||||||
|
def get_redirect_uri(self, app_config, redirect_suffix=''):
|
||||||
|
return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
|
||||||
|
app_config['SERVER_HOSTNAME'],
|
||||||
|
self.service_name().lower(),
|
||||||
|
redirect_suffix)
|
||||||
|
|
||||||
|
def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
|
||||||
|
redirect_suffix='', client_auth=False):
|
||||||
|
payload = {
|
||||||
|
'code': code,
|
||||||
|
'grant_type': 'authorization_code',
|
||||||
|
'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'Accept': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
auth = None
|
||||||
|
if client_auth:
|
||||||
|
auth = (self.client_id(), self.client_secret())
|
||||||
|
else:
|
||||||
|
payload['client_id'] = self.client_id()
|
||||||
|
payload['client_secret'] = self.client_secret()
|
||||||
|
|
||||||
|
token_url = self.token_endpoint()
|
||||||
|
if form_encode:
|
||||||
|
get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
|
||||||
|
else:
|
||||||
|
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
|
||||||
|
|
||||||
|
if get_access_token.status_code / 100 != 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_data = get_access_token.json()
|
||||||
|
if not json_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return json_data.get('access_token', None)
|
|
@ -1,87 +1,9 @@
|
||||||
import urlparse
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
from cachetools import TTLCache
|
|
||||||
from cachetools.func import lru_cache
|
|
||||||
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
|
||||||
|
|
||||||
from jwkest.jwk import KEYS
|
|
||||||
from util import slash_join
|
from util import slash_join
|
||||||
|
from util.oauth.base import OAuthService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
class GithubOAuthService(OAuthService):
|
||||||
|
|
||||||
class OAuthConfig(object):
|
|
||||||
def __init__(self, config, key_name):
|
def __init__(self, config, key_name):
|
||||||
self.key_name = key_name
|
super(GithubOAuthService, self).__init__(config, key_name)
|
||||||
self.config = config.get(key_name) or {}
|
|
||||||
|
|
||||||
def service_name(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def token_endpoint(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def user_endpoint(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def validate_client_id_and_secret(self, http_client, app_config):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def client_id(self):
|
|
||||||
return self.config.get('CLIENT_ID')
|
|
||||||
|
|
||||||
def client_secret(self):
|
|
||||||
return self.config.get('CLIENT_SECRET')
|
|
||||||
|
|
||||||
def get_redirect_uri(self, app_config, redirect_suffix=''):
|
|
||||||
return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
|
|
||||||
app_config['SERVER_HOSTNAME'],
|
|
||||||
self.service_name().lower(),
|
|
||||||
redirect_suffix)
|
|
||||||
|
|
||||||
|
|
||||||
def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
|
|
||||||
redirect_suffix='', client_auth=False):
|
|
||||||
payload = {
|
|
||||||
'code': code,
|
|
||||||
'grant_type': 'authorization_code',
|
|
||||||
'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
'Accept': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
auth = None
|
|
||||||
if client_auth:
|
|
||||||
auth = (self.client_id(), self.client_secret())
|
|
||||||
else:
|
|
||||||
payload['client_id'] = self.client_id()
|
|
||||||
payload['client_secret'] = self.client_secret()
|
|
||||||
|
|
||||||
token_url = self.token_endpoint()
|
|
||||||
if form_encode:
|
|
||||||
get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
|
|
||||||
else:
|
|
||||||
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
|
|
||||||
|
|
||||||
if get_access_token.status_code / 100 != 2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
json_data = get_access_token.json()
|
|
||||||
if not json_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return json_data.get('access_token', None)
|
|
||||||
|
|
||||||
|
|
||||||
class GithubOAuthConfig(OAuthConfig):
|
|
||||||
def __init__(self, config, key_name):
|
|
||||||
super(GithubOAuthConfig, self).__init__(config, key_name)
|
|
||||||
|
|
||||||
def service_name(self):
|
def service_name(self):
|
||||||
return 'GitHub'
|
return 'GitHub'
|
||||||
|
@ -174,10 +96,9 @@ class GithubOAuthConfig(OAuthConfig):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthService(OAuthService):
|
||||||
class GoogleOAuthConfig(OAuthConfig):
|
|
||||||
def __init__(self, config, key_name):
|
def __init__(self, config, key_name):
|
||||||
super(GoogleOAuthConfig, self).__init__(config, key_name)
|
super(GoogleOAuthService, self).__init__(config, key_name)
|
||||||
|
|
||||||
def service_name(self):
|
def service_name(self):
|
||||||
return 'Google'
|
return 'Google'
|
||||||
|
@ -215,13 +136,16 @@ class GoogleOAuthConfig(OAuthConfig):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class GitLabOAuthConfig(OAuthConfig):
|
class GitLabOAuthService(OAuthService):
|
||||||
def __init__(self, config, key_name):
|
def __init__(self, config, key_name):
|
||||||
super(GitLabOAuthConfig, self).__init__(config, key_name)
|
super(GitLabOAuthService, self).__init__(config, key_name)
|
||||||
|
|
||||||
def _endpoint(self):
|
def _endpoint(self):
|
||||||
return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')
|
return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')
|
||||||
|
|
||||||
|
def user_endpoint(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def api_endpoint(self):
|
def api_endpoint(self):
|
||||||
return self._endpoint()
|
return self._endpoint()
|
||||||
|
|
||||||
|
@ -262,90 +186,3 @@ class GitLabOAuthConfig(OAuthConfig):
|
||||||
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
|
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
|
||||||
'GITLAB_ENDPOINT': self._endpoint(),
|
'GITLAB_ENDPOINT': self._endpoint(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
OIDC_WELLKNOWN = ".well-known/openid-configuration"
|
|
||||||
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour
|
|
||||||
|
|
||||||
class OIDCConfig(OAuthConfig):
|
|
||||||
def __init__(self, config, key_name):
|
|
||||||
super(OIDCConfig, self).__init__(config, key_name)
|
|
||||||
|
|
||||||
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
|
|
||||||
self._config = config
|
|
||||||
self._http_client = config['HTTPCLIENT']
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def _oidc_config(self):
|
|
||||||
if self.config.get('OIDC_SERVER'):
|
|
||||||
return self._load_via_discovery(self._config.get('DEBUGGING', False))
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _load_via_discovery(self, is_debugging):
|
|
||||||
oidc_server = self.config['OIDC_SERVER']
|
|
||||||
if not oidc_server.startswith('https://') and not is_debugging:
|
|
||||||
raise Exception('OIDC server must be accessed over SSL')
|
|
||||||
|
|
||||||
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
|
|
||||||
discovery = self._http_client.get(discovery_url, timeout=5)
|
|
||||||
|
|
||||||
if discovery.status_code / 100 != 2:
|
|
||||||
raise Exception("Could not load OIDC discovery information")
|
|
||||||
|
|
||||||
try:
|
|
||||||
return json.loads(discovery.text)
|
|
||||||
except ValueError:
|
|
||||||
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
|
|
||||||
raise Exception("Could not parse OIDC discovery information")
|
|
||||||
|
|
||||||
def authorize_endpoint(self):
|
|
||||||
return self._oidc_config().get('authorization_endpoint', '') + '?'
|
|
||||||
|
|
||||||
def token_endpoint(self):
|
|
||||||
return self._oidc_config().get('token_endpoint')
|
|
||||||
|
|
||||||
def user_endpoint(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def validate_client_id_and_secret(self, http_client, app_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_public_config(self):
|
|
||||||
return {
|
|
||||||
'CLIENT_ID': self.client_id(),
|
|
||||||
'AUTHORIZE_ENDPOINT': self.authorize_endpoint()
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def issuer(self):
|
|
||||||
return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER'])
|
|
||||||
|
|
||||||
def get_public_key(self, force_refresh=False):
|
|
||||||
""" Retrieves the public key for this handler. """
|
|
||||||
# If force_refresh is true, we expire all the items in the cache by setting the time to
|
|
||||||
# the current time + the expiration TTL.
|
|
||||||
if force_refresh:
|
|
||||||
self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)
|
|
||||||
|
|
||||||
# Retrieve the public key from the cache. If the cache does not contain the public key,
|
|
||||||
# it will internally call _get_public_key to retrieve it and then save it. The None is
|
|
||||||
# a random key chose to be stored in the cache, and could be anything.
|
|
||||||
return self._public_key_cache[None]
|
|
||||||
|
|
||||||
def _get_public_key(self, _):
|
|
||||||
""" Retrieves the public key for this handler. """
|
|
||||||
keys_url = self._oidc_config()['jwks_uri']
|
|
||||||
|
|
||||||
keys = KEYS()
|
|
||||||
keys.load_from_url(keys_url)
|
|
||||||
|
|
||||||
if not list(keys):
|
|
||||||
raise Exception('No keys provided by OIDC provider')
|
|
||||||
|
|
||||||
rsa_key = list(keys)[0]
|
|
||||||
rsa_key.deserialize()
|
|
||||||
|
|
||||||
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
|
|
||||||
# issues.
|
|
||||||
return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())
|
|
Reference in a new issue