Some fixes and tests for v2 auth

Fixes #395
This commit is contained in:
Jake Moshenko 2015-09-10 12:24:33 -04:00
parent 5029ab62f6
commit 9c3ddf846f
7 changed files with 290 additions and 35 deletions

View file

@ -1,7 +1,7 @@
import logging
import re
from datetime import datetime, timedelta
from jsonschema import validate, ValidationError
from functools import wraps
from flask import request
from flask.ext.principal import identity_changed, Identity
@ -20,7 +20,45 @@ from util.security import strictjwt
logger = logging.getLogger(__name__)
TOKEN_REGEX = re.compile(r'Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)')
TOKEN_REGEX = re.compile(r'^Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)$')
ACCESS_SCHEMA = {
'type': 'array',
'description': 'List of access granted to the subject',
'items': {
'type': 'object',
'required': [
'type',
'name',
'actions',
],
'properties': {
'type': {
'type': 'string',
'description': 'We only allow repository permissions',
'enum': [
'repository',
],
},
'name': {
'type': 'string',
'description': 'The name of the repository for which we are receiving access'
},
'actions': {
'type': 'array',
'description': 'List of specific verbs which can be performed against repository',
'items': {
'type': 'string',
'enum': [
'push',
'pull',
],
},
},
},
},
}
class InvalidJWTException(Exception):
@ -36,7 +74,7 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
# Extract the jwt token from the header
match = TOKEN_REGEX.match(bearer_token)
if match is None or match.end() != len(bearer_token):
if match is None:
raise InvalidJWTException('Invalid bearer token format')
encoded = match.group(1)
@ -44,27 +82,31 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
# Load the JWT returned.
try:
payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay',
issuer='token-issuer')
expected_issuer = app.config['JWT_AUTH_TOKEN_ISSUER']
audience = app.config['SERVER_HOSTNAME']
max_exp = strictjwt.exp_max_s_option(max_signed_s)
payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience=audience,
issuer=expected_issuer, options=max_exp)
except strictjwt.InvalidTokenError:
logger.exception('Invalid token reason')
raise InvalidJWTException('Invalid token')
if not 'sub' in payload:
raise InvalidJWTException('Missing sub field in JWT')
# Verify that the expiration is no more than 300 seconds in the future.
if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s):
raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s)
username = payload['sub']
loaded_identity = Identity(username, 'signed_jwt')
# Process the grants from the payload
if 'access' in payload:
for grant in payload['access']:
if grant['type'] != 'repository':
continue
if 'access' in payload:
try:
validate(payload['access'], ACCESS_SCHEMA)
except ValidationError:
logger.exception('We should not be minting invalid credentials')
raise InvalidJWTException('Token contained invalid or malformed access grants')
for grant in payload['access']:
namespace, repo_name = parse_namespace_repository(grant['name'])
if 'push' in grant['actions']:
@ -88,7 +130,7 @@ def process_jwt_auth(func):
logger.debug('Called with params: %s, %s', args, kwargs)
auth = request.headers.get('authorization', '').strip()
if auth:
max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 300)
max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 3660)
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
public_key = load_public_key(certificate_file_path)

View file

@ -222,7 +222,8 @@ class DefaultConfig(object):
SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
# Registry v2 JWT Auth config
JWT_AUTH_MAX_FRESH_S = 60 * 5 # At most the JWT can be signed for 300s in the future
JWT_AUTH_MAX_FRESH_S = 60 * 60 + 60 # At most signed for one hour, accounting for clock skew
JWT_AUTH_TOKEN_ISSUER = 'quay-test-issuer'
JWT_AUTH_CERTIFICATE_PATH = 'conf/selfsigned/jwt.crt'
JWT_AUTH_PRIVATE_KEY_PATH = 'conf/selfsigned/jwt.key.insecure'

View file

@ -2,7 +2,6 @@ import logging
import json
import os
from datetime import datetime, timedelta
from data.users.federated import FederatedUsers, VerifiedCredentials
from util.security import strictjwt
@ -46,9 +45,11 @@ class ExternalJWTAuthN(FederatedUsers):
# Load the JWT returned.
encoded = result_data.get('token', '')
exp_limit_options = strictjwt.exp_max_s_option(self.max_fresh_s)
try:
payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'],
audience='quay.io/jwtauthn', issuer=self.issuer)
audience='quay.io/jwtauthn', issuer=self.issuer,
options=exp_limit_options)
except strictjwt.InvalidTokenError:
logger.exception('Exception when decoding returned JWT')
return (None, 'Invalid username or password')
@ -59,16 +60,6 @@ class ExternalJWTAuthN(FederatedUsers):
if not 'email' in payload:
raise Exception('Missing email field in JWT')
if not 'exp' in payload:
raise Exception('Missing exp field in JWT')
# Verify that the expiration is no more than self.max_fresh_s seconds in the future.
expiration = datetime.utcfromtimestamp(payload['exp'])
if expiration > datetime.utcnow() + timedelta(seconds=self.max_fresh_s):
logger.debug('Payload expiration is outside of the %s second window: %s', self.max_fresh_s,
payload['exp'])
return (None, 'Invalid username or password')
# Parse out the username and email.
return (VerifiedCredentials(username=payload['sub'], email=payload['email']), None)

View file

@ -6,6 +6,7 @@ import logging
from flask import Blueprint, make_response, url_for, request, jsonify
from functools import wraps
from urlparse import urlparse
from util import get_app_url
from app import metric_queue
from endpoints.decorators import anon_protect, anon_allowed
@ -69,12 +70,11 @@ def v2_support_enabled():
if get_grant_user_context() is None:
response = make_response('true', 401)
realm_hostname = urlparse(request.url).netloc
realm_auth_path = url_for('v2.generate_registry_jwt')
scheme = app.config['PREFERRED_URL_SCHEME']
authenticate = 'Bearer realm="{0}://{1}{2}",service="quay"'.format(scheme, realm_hostname,
realm_auth_path)
authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(app.config),
realm_auth_path,
app.config['SERVER_HOSTNAME'])
response.headers['WWW-Authenticate'] = authenticate
response.headers['Docker-Distribution-API-Version'] = 'registry/2.0'

View file

@ -24,9 +24,11 @@ from endpoints.decorators import anon_protect
logger = logging.getLogger(__name__)
TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour
SCOPE_REGEX = re.compile(
r'^repository:([\.a-zA-Z0-9_\-]+/[\.a-zA-Z0-9_\-]+):(((push|pull|\*),)*(push|pull|\*))$'
)
ANONYMOUS_SUB = '(anonymous)'
@lru_cache(maxsize=1)
@ -89,7 +91,6 @@ def generate_registry_jwt():
not model.repository.repository_is_public(namespace, reponame)):
abort(403)
access.append({
'type': 'repository',
'name': namespace_and_repo,
@ -97,11 +98,12 @@ def generate_registry_jwt():
})
token_data = {
'iss': 'token-issuer',
'iss': app.config['JWT_AUTH_TOKEN_ISSUER'],
'aud': audience_param,
'nbf': int(time.time()),
'exp': int(time.time() + 60),
'sub': user.username if user else '(anonymous)',
'iat': int(time.time()),
'exp': int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
'sub': user.username if user else ANONYMOUS_SUB,
'access': access,
}

View file

@ -0,0 +1,185 @@
import unittest
import time
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from app import app
from endpoints.v2.v2auth import (TOKEN_VALIDITY_LIFETIME_S, load_certificate_bytes,
load_private_key, ANONYMOUS_SUB)
from auth.jwt_auth import identity_from_bearer_token, load_public_key, InvalidJWTException
from util.morecollections import AttrDict
TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
TEST_USER = AttrDict({'username': 'joeuser'})
MAX_SIGNED_S = 3660
class TestRegistryV2Auth(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestRegistryV2Auth, self).__init__(*args, **kwargs)
self.public_key = None
def setUp(self):
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
self.public_key = load_public_key(certificate_file_path)
def _generate_token_data(self, access=[], audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
exp=None, nbf=None, iss=app.config['JWT_AUTH_TOKEN_ISSUER']):
return {
'iss': iss,
'aud': audience,
'nbf': nbf if nbf is not None else int(time.time()),
'iat': iat if iat is not None else int(time.time()),
'exp': exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
'sub': user.username if user else ANONYMOUS_SUB,
'access': access,
}
def _generate_token(self, token_data):
certificate = load_certificate_bytes(app.config['JWT_AUTH_CERTIFICATE_PATH'])
token_headers = {
'x5c': [certificate],
}
private_key = load_private_key(app.config['JWT_AUTH_PRIVATE_KEY_PATH'])
token_data = jwt.encode(token_data, private_key, 'RS256', headers=token_headers)
return 'Bearer {0}'.format(token_data)
def _parse_token(self, token):
return identity_from_bearer_token(token, MAX_SIGNED_S, self.public_key)
def _generate_public_key(self):
key = rsa.generate_private_key(
public_exponent=65537,
key_size=1024,
backend=default_backend()
)
return key.public_key()
def test_accepted_token(self):
token = self._generate_token(self._generate_token_data())
identity = self._parse_token(token)
self.assertEqual(identity.id, TEST_USER.username)
self.assertEqual(0, len(identity.provides))
anon_token = self._generate_token(self._generate_token_data(user=None))
anon_identity = self._parse_token(anon_token)
self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
self.assertEqual(0, len(identity.provides))
def test_token_with_access(self):
access = [
{
'type': 'repository',
'name': 'somens/somerepo',
'actions': ['pull', 'push'],
}
]
token = self._generate_token(self._generate_token_data(access=access))
identity = self._parse_token(token)
self.assertEqual(identity.id, TEST_USER.username)
self.assertEqual(1, len(identity.provides))
def test_malformed_access(self):
access = [
{
'toipe': 'repository',
'namesies': 'somens/somerepo',
'akshuns': ['pull', 'push'],
}
]
token = self._generate_token(self._generate_token_data(access=access))
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
def test_bad_signature(self):
token = self._generate_token(self._generate_token_data())
other_public_key = self._generate_public_key()
with self.assertRaises(InvalidJWTException):
identity_from_bearer_token(token, MAX_SIGNED_S, other_public_key)
def test_audience(self):
token_data = self._generate_token_data(audience='someotherapp')
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('aud')
no_aud = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_aud)
def test_nbf(self):
future = int(time.time()) + 60
token_data = self._generate_token_data(nbf=future)
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('nbf')
no_nbf_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_nbf_token)
def test_iat(self):
future = int(time.time()) + 60
token_data = self._generate_token_data(iat=future)
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('iat')
no_iat_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_iat_token)
def test_exp(self):
too_far = int(time.time()) + MAX_SIGNED_S * 2
token_data = self._generate_token_data(exp=too_far)
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
past = int(time.time()) - 60
token_data['exp'] = past
expired_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(expired_token)
token_data.pop('exp')
no_exp_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_exp_token)
def test_no_sub(self):
token_data = self._generate_token_data()
token_data.pop('sub')
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
def test_iss(self):
token_data = self._generate_token_data(iss='badissuer')
token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(token)
token_data.pop('iss')
no_iss_token = self._generate_token(token_data)
with self.assertRaises(InvalidJWTException):
self._parse_token(no_iss_token)
if __name__ == '__main__':
import logging
logging.basicConfig(level=logging.DEBUG)
unittest.main()

View file

@ -1,3 +1,4 @@
from datetime import datetime, timedelta
from jwt import PyJWT
from jwt.exceptions import (
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
@ -14,8 +15,41 @@ class StrictJWT(PyJWT):
'require_exp': True,
'require_iat': True,
'require_nbf': True,
'exp_max_s': None,
})
return defaults
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
if options.get('exp_max_s') is not None:
if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
raise ValueError('exp_max_s option implies verify_expiration')
options['verify_exp'] = True
# Do all of the other checks
super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
if 'exp' in payload and options.get('exp_max_s') is not None:
# Validate that the expiration was not more than exp_max_s seconds after the issue time
# or in the absense of an issue time, more than exp_max_s in the future from now
# This will work because the parent method already checked the type of exp
expiration = datetime.utcfromtimestamp(int(payload['exp']))
max_signed_s = options.get('exp_max_s')
start_time = datetime.utcnow()
if 'iat' in payload:
start_time = datetime.utcfromtimestamp(int(payload['iat']))
if expiration > start_time + timedelta(seconds=max_signed_s):
raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s,
start_time)
def exp_max_s_option(max_exp_s):
return {
'exp_max_s': max_exp_s,
}
decode = StrictJWT().decode