diff --git a/app.py b/app.py index 953af51d8..e89328fbc 100644 --- a/app.py +++ b/app.py @@ -31,19 +31,15 @@ from util.saas.analytics import Analytics from util.saas.useranalytics import UserAnalytics from util.saas.exceptionlog import Sentry from util.names import urn_generator +from util.oauth.services import GoogleOAuthService, GithubOAuthService, GitLabOAuthService from util.config.configutil import generate_secret_key -from util.config.oauth import (GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthConfig, - DexOAuthConfig) from util.config.provider import get_config_provider from util.config.superusermanager import SuperUserManager from util.label_validator import LabelValidator -from util.license import LicenseValidator, LICENSE_FILENAME +from util.license import LicenseValidator from util.metrics.metricqueue import MetricQueue from util.metrics.prometheus import PrometheusPlugin -from util.names import urn_generator -from util.saas.analytics import Analytics from util.saas.cloudwatch import start_cloudwatch_sender -from util.saas.exceptionlog import Sentry from util.secscan.api import SecurityScannerAPI from util.security.instancekeys import InstanceKeys from util.security.signing import Signer @@ -204,13 +200,12 @@ license_validator.start() start_cloudwatch_sender(metric_queue, app) -github_login = GithubOAuthConfig(app.config, 'GITHUB_LOGIN_CONFIG') -github_trigger = GithubOAuthConfig(app.config, 'GITHUB_TRIGGER_CONFIG') -gitlab_trigger = GitLabOAuthConfig(app.config, 'GITLAB_TRIGGER_CONFIG') -google_login = GoogleOAuthConfig(app.config, 'GOOGLE_LOGIN_CONFIG') -dex_login = DexOAuthConfig(app.config, 'DEX_LOGIN_CONFIG') +github_login = GithubOAuthService(app.config, 'GITHUB_LOGIN_CONFIG') +github_trigger = GithubOAuthService(app.config, 'GITHUB_TRIGGER_CONFIG') +gitlab_trigger = GitLabOAuthService(app.config, 'GITLAB_TRIGGER_CONFIG') +google_login = GoogleOAuthService(app.config, 'GOOGLE_LOGIN_CONFIG') -oauth_apps = [github_login, github_trigger, gitlab_trigger, google_login, dex_login] +oauth_apps = [github_login, github_trigger, gitlab_trigger, google_login] image_replication_queue = WorkQueue(app.config['REPLICATION_QUEUE_NAME'], tf, has_namespace=False) dockerfile_build_queue = WorkQueue(app.config['DOCKERFILE_BUILD_QUEUE_NAME'], tf, @@ -240,7 +235,7 @@ model.config.register_image_cleanup_callback(secscan_api.cleanup_layers) @login_manager.user_loader def load_user(user_uuid): - logger.debug('User loader loading deferred user with uuid: %s' % user_uuid) + logger.debug('User loader loading deferred user with uuid: %s', user_uuid) return LoginWrappedDBUser(user_uuid) class LoginWrappedDBUser(UserMixin): diff --git a/endpoints/oauthlogin.py b/endpoints/oauthlogin.py index bbb5d86f5..98cca1839 100644 --- a/endpoints/oauthlogin.py +++ b/endpoints/oauthlogin.py @@ -7,7 +7,7 @@ from peewee import IntegrityError import features -from app import app, analytics, get_app_url, github_login, google_login, dex_login +from app import app, analytics, get_app_url, github_login, google_login from auth.process import require_session_login from data import model from endpoints.common import common_login, route_show_if @@ -281,14 +281,3 @@ def github_oauth_attach(): return redirect(url_for('web.user_view', path=user_obj.username, tab='external')) -def decode_user_jwt(token, oidc_provider): - try: - return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'], - audience=oidc_provider.client_id(), - issuer=oidc_provider.issuer) - except InvalidTokenError: - # Public key may have expired. Try to retrieve an updated public key and use it to decode. - return decode(token, oidc_provider.get_public_key(force_refresh=True), algorithms=['RS256'], - audience=oidc_provider.client_id(), - issuer=oidc_provider.issuer) - diff --git a/test/test_endpoints.py b/test/test_endpoints.py index 685946ddf..bf12038d3 100644 --- a/test/test_endpoints.py +++ b/test/test_endpoints.py @@ -218,50 +218,6 @@ class OAuthLoginTestCase(EndpointTestCase): self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github', 'someid', 'someusername') - def test_dex_oauth(self): - # TODO(jschorr): Add tests for invalid and expired keys. - - # Generate a public/private key pair for the OIDC transaction. - private_key = RSA.generate(2048) - jwk = RSAKey(key=private_key.publickey()).serialize() - token = jwt.encode({ - 'iss': 'https://oidcserver/', - 'aud': 'someclientid', - 'sub': 'someid', - 'exp': int(time.time()) + 60, - 'iat': int(time.time()), - 'nbf': int(time.time()), - 'email': 'someemail@example.com', - 'email_verified': True, - }, private_key.exportKey('PEM'), 'RS256') - - @urlmatch(netloc=r'oidcserver', path='/.well-known/openid-configuration') - def wellknown_handler(url, _): - return py_json.dumps({ - 'authorization_endpoint': 'http://oidcserver/auth', - 'token_endpoint': 'http://oidcserver/token', - 'jwks_uri': 'http://oidcserver/keys', - }) - - @urlmatch(netloc=r'oidcserver', path='/token') - def account_handler(url, request): - if request.body.find("code=somecode") > 0: - return py_json.dumps({ - 'access_token': token, - }) - else: - return {'status_code': 400, 'content': '{"message": "Invalid code"}'} - - @urlmatch(netloc=r'oidcserver', path='/keys') - def keys_handler(_, __): - return py_json.dumps({ - "keys": [jwk], - }) - - with HTTMock(wellknown_handler, account_handler, keys_handler): - self.invoke_oauth_tests('dex_oauth_callback', 'dex_oauth_attach', 'dex', - 'someid', 'someemail') - class WebEndpointTestCase(EndpointTestCase): def test_index(self): diff --git a/util/oauth/__init__.py b/util/oauth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/util/oauth/base.py b/util/oauth/base.py new file mode 100644 index 000000000..9ede7cfdc --- /dev/null +++ b/util/oauth/base.py @@ -0,0 +1,63 @@ +class OAuthService(object): + """ A base class for defining an external service, exposed via OAuth. """ + def __init__(self, config, key_name): + self.key_name = key_name + self.config = config.get(key_name) or {} + + def service_name(self): + raise NotImplementedError + + def token_endpoint(self): + raise NotImplementedError + + def user_endpoint(self): + raise NotImplementedError + + def validate_client_id_and_secret(self, http_client, app_config): + raise NotImplementedError + + def client_id(self): + return self.config.get('CLIENT_ID') + + def client_secret(self): + return self.config.get('CLIENT_SECRET') + + def get_redirect_uri(self, app_config, redirect_suffix=''): + return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], + app_config['SERVER_HOSTNAME'], + self.service_name().lower(), + redirect_suffix) + + def exchange_code_for_token(self, app_config, http_client, code, form_encode=False, + redirect_suffix='', client_auth=False): + payload = { + 'code': code, + 'grant_type': 'authorization_code', + 'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix) + } + + headers = { + 'Accept': 'application/json' + } + + auth = None + if client_auth: + auth = (self.client_id(), self.client_secret()) + else: + payload['client_id'] = self.client_id() + payload['client_secret'] = self.client_secret() + + token_url = self.token_endpoint() + if form_encode: + get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth) + else: + get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth) + + if get_access_token.status_code / 100 != 2: + return None + + json_data = get_access_token.json() + if not json_data: + return None + + return json_data.get('access_token', None) diff --git a/util/config/oauth.py b/util/oauth/services.py similarity index 52% rename from util/config/oauth.py rename to util/oauth/services.py index 7676e7954..d5877fc14 100644 --- a/util/config/oauth.py +++ b/util/oauth/services.py @@ -1,87 +1,9 @@ -import urlparse -import json -import logging -import time - -from cachetools import TTLCache -from cachetools.func import lru_cache - -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.serialization import load_der_public_key - -from jwkest.jwk import KEYS from util import slash_join +from util.oauth.base import OAuthService -logger = logging.getLogger(__name__) - -class OAuthConfig(object): +class GithubOAuthService(OAuthService): def __init__(self, config, key_name): - self.key_name = key_name - self.config = config.get(key_name) or {} - - def service_name(self): - raise NotImplementedError - - def token_endpoint(self): - raise NotImplementedError - - def user_endpoint(self): - raise NotImplementedError - - def validate_client_id_and_secret(self, http_client, app_config): - raise NotImplementedError - - def client_id(self): - return self.config.get('CLIENT_ID') - - def client_secret(self): - return self.config.get('CLIENT_SECRET') - - def get_redirect_uri(self, app_config, redirect_suffix=''): - return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], - app_config['SERVER_HOSTNAME'], - self.service_name().lower(), - redirect_suffix) - - - def exchange_code_for_token(self, app_config, http_client, code, form_encode=False, - redirect_suffix='', client_auth=False): - payload = { - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix) - } - - headers = { - 'Accept': 'application/json' - } - - auth = None - if client_auth: - auth = (self.client_id(), self.client_secret()) - else: - payload['client_id'] = self.client_id() - payload['client_secret'] = self.client_secret() - - token_url = self.token_endpoint() - if form_encode: - get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth) - else: - get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth) - - if get_access_token.status_code / 100 != 2: - return None - - json_data = get_access_token.json() - if not json_data: - return None - - return json_data.get('access_token', None) - - -class GithubOAuthConfig(OAuthConfig): - def __init__(self, config, key_name): - super(GithubOAuthConfig, self).__init__(config, key_name) + super(GithubOAuthService, self).__init__(config, key_name) def service_name(self): return 'GitHub' @@ -174,10 +96,9 @@ class GithubOAuthConfig(OAuthConfig): } - -class GoogleOAuthConfig(OAuthConfig): +class GoogleOAuthService(OAuthService): def __init__(self, config, key_name): - super(GoogleOAuthConfig, self).__init__(config, key_name) + super(GoogleOAuthService, self).__init__(config, key_name) def service_name(self): return 'Google' @@ -215,13 +136,16 @@ class GoogleOAuthConfig(OAuthConfig): } -class GitLabOAuthConfig(OAuthConfig): +class GitLabOAuthService(OAuthService): def __init__(self, config, key_name): - super(GitLabOAuthConfig, self).__init__(config, key_name) + super(GitLabOAuthService, self).__init__(config, key_name) def _endpoint(self): return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com') + def user_endpoint(self): + raise NotImplementedError + def api_endpoint(self): return self._endpoint() @@ -262,90 +186,3 @@ class GitLabOAuthConfig(OAuthConfig): 'AUTHORIZE_ENDPOINT': self.authorize_endpoint(), 'GITLAB_ENDPOINT': self._endpoint(), } - - -OIDC_WELLKNOWN = ".well-known/openid-configuration" -PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour - -class OIDCConfig(OAuthConfig): - def __init__(self, config, key_name): - super(OIDCConfig, self).__init__(config, key_name) - - self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key) - self._config = config - self._http_client = config['HTTPCLIENT'] - - @lru_cache(maxsize=1) - def _oidc_config(self): - if self.config.get('OIDC_SERVER'): - return self._load_via_discovery(self._config.get('DEBUGGING', False)) - else: - return {} - - def _load_via_discovery(self, is_debugging): - oidc_server = self.config['OIDC_SERVER'] - if not oidc_server.startswith('https://') and not is_debugging: - raise Exception('OIDC server must be accessed over SSL') - - discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN) - discovery = self._http_client.get(discovery_url, timeout=5) - - if discovery.status_code / 100 != 2: - raise Exception("Could not load OIDC discovery information") - - try: - return json.loads(discovery.text) - except ValueError: - logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) - raise Exception("Could not parse OIDC discovery information") - - def authorize_endpoint(self): - return self._oidc_config().get('authorization_endpoint', '') + '?' - - def token_endpoint(self): - return self._oidc_config().get('token_endpoint') - - def user_endpoint(self): - return None - - def validate_client_id_and_secret(self, http_client, app_config): - pass - - def get_public_config(self): - return { - 'CLIENT_ID': self.client_id(), - 'AUTHORIZE_ENDPOINT': self.authorize_endpoint() - } - - @property - def issuer(self): - return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER']) - - def get_public_key(self, force_refresh=False): - """ Retrieves the public key for this handler. """ - # If force_refresh is true, we expire all the items in the cache by setting the time to - # the current time + the expiration TTL. - if force_refresh: - self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL) - - # Retrieve the public key from the cache. If the cache does not contain the public key, - # it will internally call _get_public_key to retrieve it and then save it. The None is - # a random key chose to be stored in the cache, and could be anything. - return self._public_key_cache[None] - - def _get_public_key(self, _): - """ Retrieves the public key for this handler. """ - keys_url = self._oidc_config()['jwks_uri'] - - keys = KEYS() - keys.load_from_url(keys_url) - - if not list(keys): - raise Exception('No keys provided by OIDC provider') - - rsa_key = list(keys)[0] - rsa_key.deserialize() - - # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing - # issues. - return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())