diff --git a/auth/jwt_auth.py b/auth/jwt_auth.py index 9a4aa1bbe..e4d1b15a5 100644 --- a/auth/jwt_auth.py +++ b/auth/jwt_auth.py @@ -1,7 +1,7 @@ import logging import re -from datetime import datetime, timedelta +from jsonschema import validate, ValidationError from functools import wraps from flask import request from flask.ext.principal import identity_changed, Identity @@ -20,7 +20,45 @@ from util.security import strictjwt logger = logging.getLogger(__name__) -TOKEN_REGEX = re.compile(r'Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)') +TOKEN_REGEX = re.compile(r'^Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)$') + + +ACCESS_SCHEMA = { + 'type': 'array', + 'description': 'List of access granted to the subject', + 'items': { + 'type': 'object', + 'required': [ + 'type', + 'name', + 'actions', + ], + 'properties': { + 'type': { + 'type': 'string', + 'description': 'We only allow repository permissions', + 'enum': [ + 'repository', + ], + }, + 'name': { + 'type': 'string', + 'description': 'The name of the repository for which we are receiving access' + }, + 'actions': { + 'type': 'array', + 'description': 'List of specific verbs which can be performed against repository', + 'items': { + 'type': 'string', + 'enum': [ + 'push', + 'pull', + ], + }, + }, + }, + }, +} class InvalidJWTException(Exception): @@ -36,7 +74,7 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key): # Extract the jwt token from the header match = TOKEN_REGEX.match(bearer_token) - if match is None or match.end() != len(bearer_token): + if match is None: raise InvalidJWTException('Invalid bearer token format') encoded = match.group(1) @@ -44,27 +82,31 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key): # Load the JWT returned. try: - payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay', - issuer='token-issuer') + expected_issuer = app.config['JWT_AUTH_TOKEN_ISSUER'] + audience = app.config['SERVER_HOSTNAME'] + max_exp = strictjwt.exp_max_s_option(max_signed_s) + payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience=audience, + issuer=expected_issuer, options=max_exp) except strictjwt.InvalidTokenError: + logger.exception('Invalid token reason') raise InvalidJWTException('Invalid token') if not 'sub' in payload: raise InvalidJWTException('Missing sub field in JWT') - # Verify that the expiration is no more than 300 seconds in the future. - if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s): - raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s) - username = payload['sub'] loaded_identity = Identity(username, 'signed_jwt') # Process the grants from the payload - if 'access' in payload: - for grant in payload['access']: - if grant['type'] != 'repository': - continue + if 'access' in payload: + try: + validate(payload['access'], ACCESS_SCHEMA) + except ValidationError: + logger.exception('We should not be minting invalid credentials') + raise InvalidJWTException('Token contained invalid or malformed access grants') + + for grant in payload['access']: namespace, repo_name = parse_namespace_repository(grant['name']) if 'push' in grant['actions']: @@ -88,7 +130,7 @@ def process_jwt_auth(func): logger.debug('Called with params: %s, %s', args, kwargs) auth = request.headers.get('authorization', '').strip() if auth: - max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 300) + max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 3660) certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH'] public_key = load_public_key(certificate_file_path) diff --git a/config.py b/config.py index cd199a031..e27b07548 100644 --- a/config.py +++ b/config.py @@ -222,7 +222,8 @@ class DefaultConfig(object): SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull # Registry v2 JWT Auth config - JWT_AUTH_MAX_FRESH_S = 60 * 5 # At most the JWT can be signed for 300s in the future + JWT_AUTH_MAX_FRESH_S = 60 * 60 + 60 # At most signed for one hour, accounting for clock skew + JWT_AUTH_TOKEN_ISSUER = 'quay-test-issuer' JWT_AUTH_CERTIFICATE_PATH = 'conf/selfsigned/jwt.crt' JWT_AUTH_PRIVATE_KEY_PATH = 'conf/selfsigned/jwt.key.insecure' diff --git a/data/users/externaljwt.py b/data/users/externaljwt.py index ac29f22a1..55008aa9d 100644 --- a/data/users/externaljwt.py +++ b/data/users/externaljwt.py @@ -2,7 +2,6 @@ import logging import json import os -from datetime import datetime, timedelta from data.users.federated import FederatedUsers, VerifiedCredentials from util.security import strictjwt @@ -46,9 +45,11 @@ class ExternalJWTAuthN(FederatedUsers): # Load the JWT returned. encoded = result_data.get('token', '') + exp_limit_options = strictjwt.exp_max_s_option(self.max_fresh_s) try: payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'], - audience='quay.io/jwtauthn', issuer=self.issuer) + audience='quay.io/jwtauthn', issuer=self.issuer, + options=exp_limit_options) except strictjwt.InvalidTokenError: logger.exception('Exception when decoding returned JWT') return (None, 'Invalid username or password') @@ -59,16 +60,6 @@ class ExternalJWTAuthN(FederatedUsers): if not 'email' in payload: raise Exception('Missing email field in JWT') - if not 'exp' in payload: - raise Exception('Missing exp field in JWT') - - # Verify that the expiration is no more than self.max_fresh_s seconds in the future. - expiration = datetime.utcfromtimestamp(payload['exp']) - if expiration > datetime.utcnow() + timedelta(seconds=self.max_fresh_s): - logger.debug('Payload expiration is outside of the %s second window: %s', self.max_fresh_s, - payload['exp']) - return (None, 'Invalid username or password') - # Parse out the username and email. return (VerifiedCredentials(username=payload['sub'], email=payload['email']), None) diff --git a/endpoints/v2/__init__.py b/endpoints/v2/__init__.py index b22af3ec2..e74e034ad 100644 --- a/endpoints/v2/__init__.py +++ b/endpoints/v2/__init__.py @@ -6,6 +6,7 @@ import logging from flask import Blueprint, make_response, url_for, request, jsonify from functools import wraps from urlparse import urlparse +from util import get_app_url from app import metric_queue from endpoints.decorators import anon_protect, anon_allowed @@ -69,12 +70,11 @@ def v2_support_enabled(): if get_grant_user_context() is None: response = make_response('true', 401) - realm_hostname = urlparse(request.url).netloc realm_auth_path = url_for('v2.generate_registry_jwt') - scheme = app.config['PREFERRED_URL_SCHEME'] - authenticate = 'Bearer realm="{0}://{1}{2}",service="quay"'.format(scheme, realm_hostname, - realm_auth_path) + authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(app.config), + realm_auth_path, + app.config['SERVER_HOSTNAME']) response.headers['WWW-Authenticate'] = authenticate response.headers['Docker-Distribution-API-Version'] = 'registry/2.0' diff --git a/endpoints/v2/v2auth.py b/endpoints/v2/v2auth.py index 3cb687128..a4129a23c 100644 --- a/endpoints/v2/v2auth.py +++ b/endpoints/v2/v2auth.py @@ -24,9 +24,11 @@ from endpoints.decorators import anon_protect logger = logging.getLogger(__name__) +TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour SCOPE_REGEX = re.compile( r'^repository:([\.a-zA-Z0-9_\-]+/[\.a-zA-Z0-9_\-]+):(((push|pull|\*),)*(push|pull|\*))$' ) +ANONYMOUS_SUB = '(anonymous)' @lru_cache(maxsize=1) @@ -89,7 +91,6 @@ def generate_registry_jwt(): not model.repository.repository_is_public(namespace, reponame)): abort(403) - access.append({ 'type': 'repository', 'name': namespace_and_repo, @@ -97,11 +98,12 @@ def generate_registry_jwt(): }) token_data = { - 'iss': 'token-issuer', + 'iss': app.config['JWT_AUTH_TOKEN_ISSUER'], 'aud': audience_param, 'nbf': int(time.time()), - 'exp': int(time.time() + 60), - 'sub': user.username if user else '(anonymous)', + 'iat': int(time.time()), + 'exp': int(time.time() + TOKEN_VALIDITY_LIFETIME_S), + 'sub': user.username if user else ANONYMOUS_SUB, 'access': access, } diff --git a/test/test_registry_v2_auth.py b/test/test_registry_v2_auth.py new file mode 100644 index 000000000..f449935f3 --- /dev/null +++ b/test/test_registry_v2_auth.py @@ -0,0 +1,185 @@ +import unittest +import time +import jwt + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric import rsa + +from app import app +from endpoints.v2.v2auth import (TOKEN_VALIDITY_LIFETIME_S, load_certificate_bytes, + load_private_key, ANONYMOUS_SUB) +from auth.jwt_auth import identity_from_bearer_token, load_public_key, InvalidJWTException +from util.morecollections import AttrDict + + +TEST_AUDIENCE = app.config['SERVER_HOSTNAME'] +TEST_USER = AttrDict({'username': 'joeuser'}) +MAX_SIGNED_S = 3660 + +class TestRegistryV2Auth(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestRegistryV2Auth, self).__init__(*args, **kwargs) + self.public_key = None + + def setUp(self): + certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH'] + self.public_key = load_public_key(certificate_file_path) + + def _generate_token_data(self, access=[], audience=TEST_AUDIENCE, user=TEST_USER, iat=None, + exp=None, nbf=None, iss=app.config['JWT_AUTH_TOKEN_ISSUER']): + return { + 'iss': iss, + 'aud': audience, + 'nbf': nbf if nbf is not None else int(time.time()), + 'iat': iat if iat is not None else int(time.time()), + 'exp': exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S), + 'sub': user.username if user else ANONYMOUS_SUB, + 'access': access, + } + + def _generate_token(self, token_data): + + certificate = load_certificate_bytes(app.config['JWT_AUTH_CERTIFICATE_PATH']) + + token_headers = { + 'x5c': [certificate], + } + + private_key = load_private_key(app.config['JWT_AUTH_PRIVATE_KEY_PATH']) + token_data = jwt.encode(token_data, private_key, 'RS256', headers=token_headers) + return 'Bearer {0}'.format(token_data) + + def _parse_token(self, token): + return identity_from_bearer_token(token, MAX_SIGNED_S, self.public_key) + + def _generate_public_key(self): + key = rsa.generate_private_key( + public_exponent=65537, + key_size=1024, + backend=default_backend() + ) + return key.public_key() + + def test_accepted_token(self): + token = self._generate_token(self._generate_token_data()) + identity = self._parse_token(token) + self.assertEqual(identity.id, TEST_USER.username) + self.assertEqual(0, len(identity.provides)) + + anon_token = self._generate_token(self._generate_token_data(user=None)) + anon_identity = self._parse_token(anon_token) + self.assertEqual(anon_identity.id, ANONYMOUS_SUB) + self.assertEqual(0, len(identity.provides)) + + def test_token_with_access(self): + access = [ + { + 'type': 'repository', + 'name': 'somens/somerepo', + 'actions': ['pull', 'push'], + } + ] + token = self._generate_token(self._generate_token_data(access=access)) + identity = self._parse_token(token) + self.assertEqual(identity.id, TEST_USER.username) + self.assertEqual(1, len(identity.provides)) + + def test_malformed_access(self): + access = [ + { + 'toipe': 'repository', + 'namesies': 'somens/somerepo', + 'akshuns': ['pull', 'push'], + } + ] + token = self._generate_token(self._generate_token_data(access=access)) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + def test_bad_signature(self): + token = self._generate_token(self._generate_token_data()) + other_public_key = self._generate_public_key() + with self.assertRaises(InvalidJWTException): + identity_from_bearer_token(token, MAX_SIGNED_S, other_public_key) + + def test_audience(self): + token_data = self._generate_token_data(audience='someotherapp') + token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + token_data.pop('aud') + no_aud = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(no_aud) + + def test_nbf(self): + future = int(time.time()) + 60 + token_data = self._generate_token_data(nbf=future) + + token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + token_data.pop('nbf') + no_nbf_token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(no_nbf_token) + + def test_iat(self): + future = int(time.time()) + 60 + token_data = self._generate_token_data(iat=future) + + token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + token_data.pop('iat') + no_iat_token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(no_iat_token) + + def test_exp(self): + too_far = int(time.time()) + MAX_SIGNED_S * 2 + token_data = self._generate_token_data(exp=too_far) + + token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + past = int(time.time()) - 60 + token_data['exp'] = past + expired_token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(expired_token) + + token_data.pop('exp') + no_exp_token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(no_exp_token) + + def test_no_sub(self): + token_data = self._generate_token_data() + token_data.pop('sub') + token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + def test_iss(self): + token_data = self._generate_token_data(iss='badissuer') + + token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(token) + + token_data.pop('iss') + no_iss_token = self._generate_token(token_data) + with self.assertRaises(InvalidJWTException): + self._parse_token(no_iss_token) + + +if __name__ == '__main__': + import logging + logging.basicConfig(level=logging.DEBUG) + unittest.main() + diff --git a/util/security/strictjwt.py b/util/security/strictjwt.py index 35f94444c..61bb61454 100644 --- a/util/security/strictjwt.py +++ b/util/security/strictjwt.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta from jwt import PyJWT from jwt.exceptions import ( InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError, @@ -14,8 +15,41 @@ class StrictJWT(PyJWT): 'require_exp': True, 'require_iat': True, 'require_nbf': True, + 'exp_max_s': None, }) return defaults + def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs): + if options.get('exp_max_s') is not None: + if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'): + raise ValueError('exp_max_s option implies verify_expiration') + + options['verify_exp'] = True + + # Do all of the other checks + super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs) + + if 'exp' in payload and options.get('exp_max_s') is not None: + # Validate that the expiration was not more than exp_max_s seconds after the issue time + # or in the absense of an issue time, more than exp_max_s in the future from now + + # This will work because the parent method already checked the type of exp + expiration = datetime.utcfromtimestamp(int(payload['exp'])) + max_signed_s = options.get('exp_max_s') + + start_time = datetime.utcnow() + if 'iat' in payload: + start_time = datetime.utcfromtimestamp(int(payload['iat'])) + + if expiration > start_time + timedelta(seconds=max_signed_s): + raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s, + start_time) + + +def exp_max_s_option(max_exp_s): + return { + 'exp_max_s': max_exp_s, + } + decode = StrictJWT().decode