Some fixes and tests for v2 auth

Fixes #395
This commit is contained in:
Jake Moshenko 2015-09-10 12:24:33 -04:00
parent 5029ab62f6
commit 9c3ddf846f
7 changed files with 290 additions and 35 deletions

View file

@ -1,7 +1,7 @@
import logging import logging
import re import re
from datetime import datetime, timedelta from jsonschema import validate, ValidationError
from functools import wraps from functools import wraps
from flask import request from flask import request
from flask.ext.principal import identity_changed, Identity from flask.ext.principal import identity_changed, Identity
@ -20,7 +20,45 @@ from util.security import strictjwt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TOKEN_REGEX = re.compile(r'Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)') TOKEN_REGEX = re.compile(r'^Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)$')
ACCESS_SCHEMA = {
'type': 'array',
'description': 'List of access granted to the subject',
'items': {
'type': 'object',
'required': [
'type',
'name',
'actions',
],
'properties': {
'type': {
'type': 'string',
'description': 'We only allow repository permissions',
'enum': [
'repository',
],
},
'name': {
'type': 'string',
'description': 'The name of the repository for which we are receiving access'
},
'actions': {
'type': 'array',
'description': 'List of specific verbs which can be performed against repository',
'items': {
'type': 'string',
'enum': [
'push',
'pull',
],
},
},
},
},
}
class InvalidJWTException(Exception): class InvalidJWTException(Exception):
@ -36,7 +74,7 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
# Extract the jwt token from the header # Extract the jwt token from the header
match = TOKEN_REGEX.match(bearer_token) match = TOKEN_REGEX.match(bearer_token)
if match is None or match.end() != len(bearer_token): if match is None:
raise InvalidJWTException('Invalid bearer token format') raise InvalidJWTException('Invalid bearer token format')
encoded = match.group(1) encoded = match.group(1)
@ -44,27 +82,31 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
# Load the JWT returned. # Load the JWT returned.
try: try:
payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay', expected_issuer = app.config['JWT_AUTH_TOKEN_ISSUER']
issuer='token-issuer') audience = app.config['SERVER_HOSTNAME']
max_exp = strictjwt.exp_max_s_option(max_signed_s)
payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience=audience,
issuer=expected_issuer, options=max_exp)
except strictjwt.InvalidTokenError: except strictjwt.InvalidTokenError:
logger.exception('Invalid token reason')
raise InvalidJWTException('Invalid token') raise InvalidJWTException('Invalid token')
if not 'sub' in payload: if not 'sub' in payload:
raise InvalidJWTException('Missing sub field in JWT') raise InvalidJWTException('Missing sub field in JWT')
# Verify that the expiration is no more than 300 seconds in the future.
if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s):
raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s)
username = payload['sub'] username = payload['sub']
loaded_identity = Identity(username, 'signed_jwt') loaded_identity = Identity(username, 'signed_jwt')
# Process the grants from the payload # Process the grants from the payload
if 'access' in payload:
for grant in payload['access']:
if grant['type'] != 'repository':
continue
if 'access' in payload:
try:
validate(payload['access'], ACCESS_SCHEMA)
except ValidationError:
logger.exception('We should not be minting invalid credentials')
raise InvalidJWTException('Token contained invalid or malformed access grants')
for grant in payload['access']:
namespace, repo_name = parse_namespace_repository(grant['name']) namespace, repo_name = parse_namespace_repository(grant['name'])
if 'push' in grant['actions']: if 'push' in grant['actions']:
@ -88,7 +130,7 @@ def process_jwt_auth(func):
logger.debug('Called with params: %s, %s', args, kwargs) logger.debug('Called with params: %s, %s', args, kwargs)
auth = request.headers.get('authorization', '').strip() auth = request.headers.get('authorization', '').strip()
if auth: if auth:
max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 300) max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 3660)
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH'] certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
public_key = load_public_key(certificate_file_path) public_key = load_public_key(certificate_file_path)

View file

@ -222,7 +222,8 @@ class DefaultConfig(object):
SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
# Registry v2 JWT Auth config # Registry v2 JWT Auth config
JWT_AUTH_MAX_FRESH_S = 60 * 5 # At most the JWT can be signed for 300s in the future JWT_AUTH_MAX_FRESH_S = 60 * 60 + 60 # At most signed for one hour, accounting for clock skew
JWT_AUTH_TOKEN_ISSUER = 'quay-test-issuer'
JWT_AUTH_CERTIFICATE_PATH = 'conf/selfsigned/jwt.crt' JWT_AUTH_CERTIFICATE_PATH = 'conf/selfsigned/jwt.crt'
JWT_AUTH_PRIVATE_KEY_PATH = 'conf/selfsigned/jwt.key.insecure' JWT_AUTH_PRIVATE_KEY_PATH = 'conf/selfsigned/jwt.key.insecure'

View file

@ -2,7 +2,6 @@ import logging
import json import json
import os import os
from datetime import datetime, timedelta
from data.users.federated import FederatedUsers, VerifiedCredentials from data.users.federated import FederatedUsers, VerifiedCredentials
from util.security import strictjwt from util.security import strictjwt
@ -46,9 +45,11 @@ class ExternalJWTAuthN(FederatedUsers):
# Load the JWT returned. # Load the JWT returned.
encoded = result_data.get('token', '') encoded = result_data.get('token', '')
exp_limit_options = strictjwt.exp_max_s_option(self.max_fresh_s)
try: try:
payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'], payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'],
audience='quay.io/jwtauthn', issuer=self.issuer) audience='quay.io/jwtauthn', issuer=self.issuer,
options=exp_limit_options)
except strictjwt.InvalidTokenError: except strictjwt.InvalidTokenError:
logger.exception('Exception when decoding returned JWT') logger.exception('Exception when decoding returned JWT')
return (None, 'Invalid username or password') return (None, 'Invalid username or password')
@ -59,16 +60,6 @@ class ExternalJWTAuthN(FederatedUsers):
if not 'email' in payload: if not 'email' in payload:
raise Exception('Missing email field in JWT') raise Exception('Missing email field in JWT')
if not 'exp' in payload:
raise Exception('Missing exp field in JWT')
# Verify that the expiration is no more than self.max_fresh_s seconds in the future.
expiration = datetime.utcfromtimestamp(payload['exp'])
if expiration > datetime.utcnow() + timedelta(seconds=self.max_fresh_s):
logger.debug('Payload expiration is outside of the %s second window: %s', self.max_fresh_s,
payload['exp'])
return (None, 'Invalid username or password')
# Parse out the username and email. # Parse out the username and email.
return (VerifiedCredentials(username=payload['sub'], email=payload['email']), None) return (VerifiedCredentials(username=payload['sub'], email=payload['email']), None)

View file

@ -6,6 +6,7 @@ import logging
from flask import Blueprint, make_response, url_for, request, jsonify from flask import Blueprint, make_response, url_for, request, jsonify
from functools import wraps from functools import wraps
from urlparse import urlparse from urlparse import urlparse
from util import get_app_url
from app import metric_queue from app import metric_queue
from endpoints.decorators import anon_protect, anon_allowed from endpoints.decorators import anon_protect, anon_allowed
@ -69,12 +70,11 @@ def v2_support_enabled():
if get_grant_user_context() is None: if get_grant_user_context() is None:
response = make_response('true', 401) response = make_response('true', 401)
realm_hostname = urlparse(request.url).netloc
realm_auth_path = url_for('v2.generate_registry_jwt') realm_auth_path = url_for('v2.generate_registry_jwt')
scheme = app.config['PREFERRED_URL_SCHEME']
authenticate = 'Bearer realm="{0}://{1}{2}",service="quay"'.format(scheme, realm_hostname, authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(app.config),
realm_auth_path) realm_auth_path,
app.config['SERVER_HOSTNAME'])
response.headers['WWW-Authenticate'] = authenticate response.headers['WWW-Authenticate'] = authenticate
response.headers['Docker-Distribution-API-Version'] = 'registry/2.0' response.headers['Docker-Distribution-API-Version'] = 'registry/2.0'

View file

@ -24,9 +24,11 @@ from endpoints.decorators import anon_protect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour
SCOPE_REGEX = re.compile( SCOPE_REGEX = re.compile(
r'^repository:([\.a-zA-Z0-9_\-]+/[\.a-zA-Z0-9_\-]+):(((push|pull|\*),)*(push|pull|\*))$' r'^repository:([\.a-zA-Z0-9_\-]+/[\.a-zA-Z0-9_\-]+):(((push|pull|\*),)*(push|pull|\*))$'
) )
ANONYMOUS_SUB = '(anonymous)'
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
@ -89,7 +91,6 @@ def generate_registry_jwt():
not model.repository.repository_is_public(namespace, reponame)): not model.repository.repository_is_public(namespace, reponame)):
abort(403) abort(403)
access.append({ access.append({
'type': 'repository', 'type': 'repository',
'name': namespace_and_repo, 'name': namespace_and_repo,
@ -97,11 +98,12 @@ def generate_registry_jwt():
}) })
token_data = { token_data = {
'iss': 'token-issuer', 'iss': app.config['JWT_AUTH_TOKEN_ISSUER'],
'aud': audience_param, 'aud': audience_param,
'nbf': int(time.time()), 'nbf': int(time.time()),
'exp': int(time.time() + 60), 'iat': int(time.time()),
'sub': user.username if user else '(anonymous)', 'exp': int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
'sub': user.username if user else ANONYMOUS_SUB,
'access': access, 'access': access,
} }

View file

@ -0,0 +1,185 @@
import unittest
import time
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from app import app
from endpoints.v2.v2auth import (TOKEN_VALIDITY_LIFETIME_S, load_certificate_bytes,
load_private_key, ANONYMOUS_SUB)
from auth.jwt_auth import identity_from_bearer_token, load_public_key, InvalidJWTException
from util.morecollections import AttrDict
TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
TEST_USER = AttrDict({'username': 'joeuser'})
MAX_SIGNED_S = 3660
class TestRegistryV2Auth(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestRegistryV2Auth, self).__init__(*args, **kwargs)
self.public_key = None
def setUp(self):
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
self.public_key = load_public_key(certificate_file_path)
def _generate_token_data(self, access=[], audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
exp=None, nbf=None, iss=app.config['JWT_AUTH_TOKEN_ISSUER']):
return {
'iss': iss,
'aud': audience,
'nbf': nbf if nbf is not None else int(time.time()),
'iat': iat if iat is not None else int(time.time()),
'exp': exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
'sub': user.username if user else ANONYMOUS_SUB,
'access': access,
}
def _generate_token(self, token_data):
certificate = load_certificate_bytes(app.config['JWT_AUTH_CERTIFICATE_PATH'])
token_headers = {
'x5c': [certificate],
}
private_key = load_private_key(app.config['JWT_AUTH_PRIVATE_KEY_PATH'])
token_data = jwt.encode(token_data, private_key, 'RS256', headers=token_headers)
return 'Bearer {0}'.format(token_data)
def _parse_token(self, token):
return identity_from_bearer_token(token, MAX_SIGNED_S, self.public_key)
def _generate_public_key(self):
key = rsa.generate_private_key(
public_exponent=65537,
key_size=1024,
backend=default_backend()
)
return key.public_key()
def test_accepted_token(self):
token = self._generate_token(self._generate_token_data())
identity = self._parse_token(token)
self.assertEqual(identity.id, TEST_USER.username)
self.assertEqual(0, len(identity.provides))
anon_token = self._generate_token(self._generate_token_data(user=None))
anon_identity = self._parse_token(anon_token)
self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
self.assertEqual(0, len(identity.provides))
def test_token_with_access(self):
access = [
{
'type': 'repository',
'name': 'somens/somerepo',
'actions': ['pull', 'push'],
}
]
token = self._generate_token(self._generate_token_data(access=access))
identity = self._parse_token(token)
self.assertEqual(identity.id, TEST_USER.username)
self.assertEqual(1, len(identity.provides))
def test_malformed_access(self):
access = [
{
'toipe': 'repository',
'namesies': 'somens/somerepo',
'akshuns': ['pull', 'push'],
}
]
token = self._generate_token(self._generate_token_data(access=access))
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
def test_bad_signature(self):
token = self._generate_token(self._generate_token_data())
other_public_key = self._generate_public_key()
with self.assertRaises(InvalidJWTException):
identity_from_bearer_token(token, MAX_SIGNED_S, other_public_key)
def test_audience(self):
token_data = self._generate_token_data(audience='someotherapp')
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('aud')
no_aud = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_aud)
def test_nbf(self):
future = int(time.time()) + 60
token_data = self._generate_token_data(nbf=future)
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('nbf')
no_nbf_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_nbf_token)
def test_iat(self):
future = int(time.time()) + 60
token_data = self._generate_token_data(iat=future)
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('iat')
no_iat_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_iat_token)
def test_exp(self):
too_far = int(time.time()) + MAX_SIGNED_S * 2
token_data = self._generate_token_data(exp=too_far)
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
past = int(time.time()) - 60
token_data['exp'] = past
expired_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(expired_token)
token_data.pop('exp')
no_exp_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_exp_token)
def test_no_sub(self):
token_data = self._generate_token_data()
token_data.pop('sub')
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
def test_iss(self):
token_data = self._generate_token_data(iss='badissuer')
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('iss')
no_iss_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_iss_token)
if __name__ == '__main__':
import logging
logging.basicConfig(level=logging.DEBUG)
unittest.main()

View file

@ -1,3 +1,4 @@
from datetime import datetime, timedelta
from jwt import PyJWT from jwt import PyJWT
from jwt.exceptions import ( from jwt.exceptions import (
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError, InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
@ -14,8 +15,41 @@ class StrictJWT(PyJWT):
'require_exp': True, 'require_exp': True,
'require_iat': True, 'require_iat': True,
'require_nbf': True, 'require_nbf': True,
'exp_max_s': None,
}) })
return defaults return defaults
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
if options.get('exp_max_s') is not None:
if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
raise ValueError('exp_max_s option implies verify_expiration')
options['verify_exp'] = True
# Do all of the other checks
super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
if 'exp' in payload and options.get('exp_max_s') is not None:
# Validate that the expiration was not more than exp_max_s seconds after the issue time
# or in the absense of an issue time, more than exp_max_s in the future from now
# This will work because the parent method already checked the type of exp
expiration = datetime.utcfromtimestamp(int(payload['exp']))
max_signed_s = options.get('exp_max_s')
start_time = datetime.utcnow()
if 'iat' in payload:
start_time = datetime.utcfromtimestamp(int(payload['iat']))
if expiration > start_time + timedelta(seconds=max_signed_s):
raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s,
start_time)
def exp_max_s_option(max_exp_s):
return {
'exp_max_s': max_exp_s,
}
decode = StrictJWT().decode decode = StrictJWT().decode