parent
5029ab62f6
commit
9c3ddf846f
7 changed files with 290 additions and 35 deletions
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from jsonschema import validate, ValidationError
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask.ext.principal import identity_changed, Identity
|
from flask.ext.principal import identity_changed, Identity
|
||||||
|
@ -20,7 +20,45 @@ from util.security import strictjwt
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
TOKEN_REGEX = re.compile(r'Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)')
|
TOKEN_REGEX = re.compile(r'^Bearer (([a-zA-Z0-9+/]+\.)+[a-zA-Z0-9+-_/]+)$')
|
||||||
|
|
||||||
|
|
||||||
|
ACCESS_SCHEMA = {
|
||||||
|
'type': 'array',
|
||||||
|
'description': 'List of access granted to the subject',
|
||||||
|
'items': {
|
||||||
|
'type': 'object',
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'name',
|
||||||
|
'actions',
|
||||||
|
],
|
||||||
|
'properties': {
|
||||||
|
'type': {
|
||||||
|
'type': 'string',
|
||||||
|
'description': 'We only allow repository permissions',
|
||||||
|
'enum': [
|
||||||
|
'repository',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
'name': {
|
||||||
|
'type': 'string',
|
||||||
|
'description': 'The name of the repository for which we are receiving access'
|
||||||
|
},
|
||||||
|
'actions': {
|
||||||
|
'type': 'array',
|
||||||
|
'description': 'List of specific verbs which can be performed against repository',
|
||||||
|
'items': {
|
||||||
|
'type': 'string',
|
||||||
|
'enum': [
|
||||||
|
'push',
|
||||||
|
'pull',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class InvalidJWTException(Exception):
|
class InvalidJWTException(Exception):
|
||||||
|
@ -36,7 +74,7 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
|
||||||
|
|
||||||
# Extract the jwt token from the header
|
# Extract the jwt token from the header
|
||||||
match = TOKEN_REGEX.match(bearer_token)
|
match = TOKEN_REGEX.match(bearer_token)
|
||||||
if match is None or match.end() != len(bearer_token):
|
if match is None:
|
||||||
raise InvalidJWTException('Invalid bearer token format')
|
raise InvalidJWTException('Invalid bearer token format')
|
||||||
|
|
||||||
encoded = match.group(1)
|
encoded = match.group(1)
|
||||||
|
@ -44,27 +82,31 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
|
||||||
|
|
||||||
# Load the JWT returned.
|
# Load the JWT returned.
|
||||||
try:
|
try:
|
||||||
payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay',
|
expected_issuer = app.config['JWT_AUTH_TOKEN_ISSUER']
|
||||||
issuer='token-issuer')
|
audience = app.config['SERVER_HOSTNAME']
|
||||||
|
max_exp = strictjwt.exp_max_s_option(max_signed_s)
|
||||||
|
payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience=audience,
|
||||||
|
issuer=expected_issuer, options=max_exp)
|
||||||
except strictjwt.InvalidTokenError:
|
except strictjwt.InvalidTokenError:
|
||||||
|
logger.exception('Invalid token reason')
|
||||||
raise InvalidJWTException('Invalid token')
|
raise InvalidJWTException('Invalid token')
|
||||||
|
|
||||||
if not 'sub' in payload:
|
if not 'sub' in payload:
|
||||||
raise InvalidJWTException('Missing sub field in JWT')
|
raise InvalidJWTException('Missing sub field in JWT')
|
||||||
|
|
||||||
# Verify that the expiration is no more than 300 seconds in the future.
|
|
||||||
if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s):
|
|
||||||
raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s)
|
|
||||||
|
|
||||||
username = payload['sub']
|
username = payload['sub']
|
||||||
loaded_identity = Identity(username, 'signed_jwt')
|
loaded_identity = Identity(username, 'signed_jwt')
|
||||||
|
|
||||||
# Process the grants from the payload
|
# Process the grants from the payload
|
||||||
if 'access' in payload:
|
|
||||||
for grant in payload['access']:
|
|
||||||
if grant['type'] != 'repository':
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
if 'access' in payload:
|
||||||
|
try:
|
||||||
|
validate(payload['access'], ACCESS_SCHEMA)
|
||||||
|
except ValidationError:
|
||||||
|
logger.exception('We should not be minting invalid credentials')
|
||||||
|
raise InvalidJWTException('Token contained invalid or malformed access grants')
|
||||||
|
|
||||||
|
for grant in payload['access']:
|
||||||
namespace, repo_name = parse_namespace_repository(grant['name'])
|
namespace, repo_name = parse_namespace_repository(grant['name'])
|
||||||
|
|
||||||
if 'push' in grant['actions']:
|
if 'push' in grant['actions']:
|
||||||
|
@ -88,7 +130,7 @@ def process_jwt_auth(func):
|
||||||
logger.debug('Called with params: %s, %s', args, kwargs)
|
logger.debug('Called with params: %s, %s', args, kwargs)
|
||||||
auth = request.headers.get('authorization', '').strip()
|
auth = request.headers.get('authorization', '').strip()
|
||||||
if auth:
|
if auth:
|
||||||
max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 300)
|
max_signature_seconds = app.config.get('JWT_AUTH_MAX_FRESH_S', 3660)
|
||||||
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
|
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
|
||||||
public_key = load_public_key(certificate_file_path)
|
public_key = load_public_key(certificate_file_path)
|
||||||
|
|
||||||
|
|
|
@ -222,7 +222,8 @@ class DefaultConfig(object):
|
||||||
SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
|
SIGNED_GRANT_EXPIRATION_SEC = 60 * 60 * 24 # One day to complete a push/pull
|
||||||
|
|
||||||
# Registry v2 JWT Auth config
|
# Registry v2 JWT Auth config
|
||||||
JWT_AUTH_MAX_FRESH_S = 60 * 5 # At most the JWT can be signed for 300s in the future
|
JWT_AUTH_MAX_FRESH_S = 60 * 60 + 60 # At most signed for one hour, accounting for clock skew
|
||||||
|
JWT_AUTH_TOKEN_ISSUER = 'quay-test-issuer'
|
||||||
JWT_AUTH_CERTIFICATE_PATH = 'conf/selfsigned/jwt.crt'
|
JWT_AUTH_CERTIFICATE_PATH = 'conf/selfsigned/jwt.crt'
|
||||||
JWT_AUTH_PRIVATE_KEY_PATH = 'conf/selfsigned/jwt.key.insecure'
|
JWT_AUTH_PRIVATE_KEY_PATH = 'conf/selfsigned/jwt.key.insecure'
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ import logging
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from data.users.federated import FederatedUsers, VerifiedCredentials
|
from data.users.federated import FederatedUsers, VerifiedCredentials
|
||||||
from util.security import strictjwt
|
from util.security import strictjwt
|
||||||
|
|
||||||
|
@ -46,9 +45,11 @@ class ExternalJWTAuthN(FederatedUsers):
|
||||||
|
|
||||||
# Load the JWT returned.
|
# Load the JWT returned.
|
||||||
encoded = result_data.get('token', '')
|
encoded = result_data.get('token', '')
|
||||||
|
exp_limit_options = strictjwt.exp_max_s_option(self.max_fresh_s)
|
||||||
try:
|
try:
|
||||||
payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'],
|
payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'],
|
||||||
audience='quay.io/jwtauthn', issuer=self.issuer)
|
audience='quay.io/jwtauthn', issuer=self.issuer,
|
||||||
|
options=exp_limit_options)
|
||||||
except strictjwt.InvalidTokenError:
|
except strictjwt.InvalidTokenError:
|
||||||
logger.exception('Exception when decoding returned JWT')
|
logger.exception('Exception when decoding returned JWT')
|
||||||
return (None, 'Invalid username or password')
|
return (None, 'Invalid username or password')
|
||||||
|
@ -59,16 +60,6 @@ class ExternalJWTAuthN(FederatedUsers):
|
||||||
if not 'email' in payload:
|
if not 'email' in payload:
|
||||||
raise Exception('Missing email field in JWT')
|
raise Exception('Missing email field in JWT')
|
||||||
|
|
||||||
if not 'exp' in payload:
|
|
||||||
raise Exception('Missing exp field in JWT')
|
|
||||||
|
|
||||||
# Verify that the expiration is no more than self.max_fresh_s seconds in the future.
|
|
||||||
expiration = datetime.utcfromtimestamp(payload['exp'])
|
|
||||||
if expiration > datetime.utcnow() + timedelta(seconds=self.max_fresh_s):
|
|
||||||
logger.debug('Payload expiration is outside of the %s second window: %s', self.max_fresh_s,
|
|
||||||
payload['exp'])
|
|
||||||
return (None, 'Invalid username or password')
|
|
||||||
|
|
||||||
# Parse out the username and email.
|
# Parse out the username and email.
|
||||||
return (VerifiedCredentials(username=payload['sub'], email=payload['email']), None)
|
return (VerifiedCredentials(username=payload['sub'], email=payload['email']), None)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
||||||
from flask import Blueprint, make_response, url_for, request, jsonify
|
from flask import Blueprint, make_response, url_for, request, jsonify
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from urlparse import urlparse
|
from urlparse import urlparse
|
||||||
|
from util import get_app_url
|
||||||
|
|
||||||
from app import metric_queue
|
from app import metric_queue
|
||||||
from endpoints.decorators import anon_protect, anon_allowed
|
from endpoints.decorators import anon_protect, anon_allowed
|
||||||
|
@ -69,12 +70,11 @@ def v2_support_enabled():
|
||||||
|
|
||||||
if get_grant_user_context() is None:
|
if get_grant_user_context() is None:
|
||||||
response = make_response('true', 401)
|
response = make_response('true', 401)
|
||||||
realm_hostname = urlparse(request.url).netloc
|
|
||||||
realm_auth_path = url_for('v2.generate_registry_jwt')
|
realm_auth_path = url_for('v2.generate_registry_jwt')
|
||||||
scheme = app.config['PREFERRED_URL_SCHEME']
|
|
||||||
|
|
||||||
authenticate = 'Bearer realm="{0}://{1}{2}",service="quay"'.format(scheme, realm_hostname,
|
authenticate = 'Bearer realm="{0}{1}",service="{2}"'.format(get_app_url(app.config),
|
||||||
realm_auth_path)
|
realm_auth_path,
|
||||||
|
app.config['SERVER_HOSTNAME'])
|
||||||
response.headers['WWW-Authenticate'] = authenticate
|
response.headers['WWW-Authenticate'] = authenticate
|
||||||
|
|
||||||
response.headers['Docker-Distribution-API-Version'] = 'registry/2.0'
|
response.headers['Docker-Distribution-API-Version'] = 'registry/2.0'
|
||||||
|
|
|
@ -24,9 +24,11 @@ from endpoints.decorators import anon_protect
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TOKEN_VALIDITY_LIFETIME_S = 60 * 60 # 1 hour
|
||||||
SCOPE_REGEX = re.compile(
|
SCOPE_REGEX = re.compile(
|
||||||
r'^repository:([\.a-zA-Z0-9_\-]+/[\.a-zA-Z0-9_\-]+):(((push|pull|\*),)*(push|pull|\*))$'
|
r'^repository:([\.a-zA-Z0-9_\-]+/[\.a-zA-Z0-9_\-]+):(((push|pull|\*),)*(push|pull|\*))$'
|
||||||
)
|
)
|
||||||
|
ANONYMOUS_SUB = '(anonymous)'
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
|
@ -89,7 +91,6 @@ def generate_registry_jwt():
|
||||||
not model.repository.repository_is_public(namespace, reponame)):
|
not model.repository.repository_is_public(namespace, reponame)):
|
||||||
abort(403)
|
abort(403)
|
||||||
|
|
||||||
|
|
||||||
access.append({
|
access.append({
|
||||||
'type': 'repository',
|
'type': 'repository',
|
||||||
'name': namespace_and_repo,
|
'name': namespace_and_repo,
|
||||||
|
@ -97,11 +98,12 @@ def generate_registry_jwt():
|
||||||
})
|
})
|
||||||
|
|
||||||
token_data = {
|
token_data = {
|
||||||
'iss': 'token-issuer',
|
'iss': app.config['JWT_AUTH_TOKEN_ISSUER'],
|
||||||
'aud': audience_param,
|
'aud': audience_param,
|
||||||
'nbf': int(time.time()),
|
'nbf': int(time.time()),
|
||||||
'exp': int(time.time() + 60),
|
'iat': int(time.time()),
|
||||||
'sub': user.username if user else '(anonymous)',
|
'exp': int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
|
||||||
|
'sub': user.username if user else ANONYMOUS_SUB,
|
||||||
'access': access,
|
'access': access,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
185
test/test_registry_v2_auth.py
Normal file
185
test/test_registry_v2_auth.py
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
import unittest
|
||||||
|
import time
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
|
||||||
|
from app import app
|
||||||
|
from endpoints.v2.v2auth import (TOKEN_VALIDITY_LIFETIME_S, load_certificate_bytes,
|
||||||
|
load_private_key, ANONYMOUS_SUB)
|
||||||
|
from auth.jwt_auth import identity_from_bearer_token, load_public_key, InvalidJWTException
|
||||||
|
from util.morecollections import AttrDict
|
||||||
|
|
||||||
|
|
||||||
|
TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
|
||||||
|
TEST_USER = AttrDict({'username': 'joeuser'})
|
||||||
|
MAX_SIGNED_S = 3660
|
||||||
|
|
||||||
|
class TestRegistryV2Auth(unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(TestRegistryV2Auth, self).__init__(*args, **kwargs)
|
||||||
|
self.public_key = None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
|
||||||
|
self.public_key = load_public_key(certificate_file_path)
|
||||||
|
|
||||||
|
def _generate_token_data(self, access=[], audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
|
||||||
|
exp=None, nbf=None, iss=app.config['JWT_AUTH_TOKEN_ISSUER']):
|
||||||
|
return {
|
||||||
|
'iss': iss,
|
||||||
|
'aud': audience,
|
||||||
|
'nbf': nbf if nbf is not None else int(time.time()),
|
||||||
|
'iat': iat if iat is not None else int(time.time()),
|
||||||
|
'exp': exp if exp is not None else int(time.time() + TOKEN_VALIDITY_LIFETIME_S),
|
||||||
|
'sub': user.username if user else ANONYMOUS_SUB,
|
||||||
|
'access': access,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_token(self, token_data):
|
||||||
|
|
||||||
|
certificate = load_certificate_bytes(app.config['JWT_AUTH_CERTIFICATE_PATH'])
|
||||||
|
|
||||||
|
token_headers = {
|
||||||
|
'x5c': [certificate],
|
||||||
|
}
|
||||||
|
|
||||||
|
private_key = load_private_key(app.config['JWT_AUTH_PRIVATE_KEY_PATH'])
|
||||||
|
token_data = jwt.encode(token_data, private_key, 'RS256', headers=token_headers)
|
||||||
|
return 'Bearer {0}'.format(token_data)
|
||||||
|
|
||||||
|
def _parse_token(self, token):
|
||||||
|
return identity_from_bearer_token(token, MAX_SIGNED_S, self.public_key)
|
||||||
|
|
||||||
|
def _generate_public_key(self):
|
||||||
|
key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537,
|
||||||
|
key_size=1024,
|
||||||
|
backend=default_backend()
|
||||||
|
)
|
||||||
|
return key.public_key()
|
||||||
|
|
||||||
|
def test_accepted_token(self):
|
||||||
|
token = self._generate_token(self._generate_token_data())
|
||||||
|
identity = self._parse_token(token)
|
||||||
|
self.assertEqual(identity.id, TEST_USER.username)
|
||||||
|
self.assertEqual(0, len(identity.provides))
|
||||||
|
|
||||||
|
anon_token = self._generate_token(self._generate_token_data(user=None))
|
||||||
|
anon_identity = self._parse_token(anon_token)
|
||||||
|
self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
|
||||||
|
self.assertEqual(0, len(identity.provides))
|
||||||
|
|
||||||
|
def test_token_with_access(self):
|
||||||
|
access = [
|
||||||
|
{
|
||||||
|
'type': 'repository',
|
||||||
|
'name': 'somens/somerepo',
|
||||||
|
'actions': ['pull', 'push'],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
token = self._generate_token(self._generate_token_data(access=access))
|
||||||
|
identity = self._parse_token(token)
|
||||||
|
self.assertEqual(identity.id, TEST_USER.username)
|
||||||
|
self.assertEqual(1, len(identity.provides))
|
||||||
|
|
||||||
|
def test_malformed_access(self):
|
||||||
|
access = [
|
||||||
|
{
|
||||||
|
'toipe': 'repository',
|
||||||
|
'namesies': 'somens/somerepo',
|
||||||
|
'akshuns': ['pull', 'push'],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
token = self._generate_token(self._generate_token_data(access=access))
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
def test_bad_signature(self):
|
||||||
|
token = self._generate_token(self._generate_token_data())
|
||||||
|
other_public_key = self._generate_public_key()
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
identity_from_bearer_token(token, MAX_SIGNED_S, other_public_key)
|
||||||
|
|
||||||
|
def test_audience(self):
|
||||||
|
token_data = self._generate_token_data(audience='someotherapp')
|
||||||
|
token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
token_data.pop('aud')
|
||||||
|
no_aud = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(no_aud)
|
||||||
|
|
||||||
|
def test_nbf(self):
|
||||||
|
future = int(time.time()) + 60
|
||||||
|
token_data = self._generate_token_data(nbf=future)
|
||||||
|
|
||||||
|
token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
token_data.pop('nbf')
|
||||||
|
no_nbf_token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(no_nbf_token)
|
||||||
|
|
||||||
|
def test_iat(self):
|
||||||
|
future = int(time.time()) + 60
|
||||||
|
token_data = self._generate_token_data(iat=future)
|
||||||
|
|
||||||
|
token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
token_data.pop('iat')
|
||||||
|
no_iat_token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(no_iat_token)
|
||||||
|
|
||||||
|
def test_exp(self):
|
||||||
|
too_far = int(time.time()) + MAX_SIGNED_S * 2
|
||||||
|
token_data = self._generate_token_data(exp=too_far)
|
||||||
|
|
||||||
|
token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
past = int(time.time()) - 60
|
||||||
|
token_data['exp'] = past
|
||||||
|
expired_token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(expired_token)
|
||||||
|
|
||||||
|
token_data.pop('exp')
|
||||||
|
no_exp_token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(no_exp_token)
|
||||||
|
|
||||||
|
def test_no_sub(self):
|
||||||
|
token_data = self._generate_token_data()
|
||||||
|
token_data.pop('sub')
|
||||||
|
token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
def test_iss(self):
|
||||||
|
token_data = self._generate_token_data(iss='badissuer')
|
||||||
|
|
||||||
|
token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(token)
|
||||||
|
|
||||||
|
token_data.pop('iss')
|
||||||
|
no_iss_token = self._generate_token(token_data)
|
||||||
|
with self.assertRaises(InvalidJWTException):
|
||||||
|
self._parse_token(no_iss_token)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
unittest.main()
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from jwt import PyJWT
|
from jwt import PyJWT
|
||||||
from jwt.exceptions import (
|
from jwt.exceptions import (
|
||||||
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
||||||
|
@ -14,8 +15,41 @@ class StrictJWT(PyJWT):
|
||||||
'require_exp': True,
|
'require_exp': True,
|
||||||
'require_iat': True,
|
'require_iat': True,
|
||||||
'require_nbf': True,
|
'require_nbf': True,
|
||||||
|
'exp_max_s': None,
|
||||||
})
|
})
|
||||||
return defaults
|
return defaults
|
||||||
|
|
||||||
|
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
|
||||||
|
if options.get('exp_max_s') is not None:
|
||||||
|
if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
|
||||||
|
raise ValueError('exp_max_s option implies verify_expiration')
|
||||||
|
|
||||||
|
options['verify_exp'] = True
|
||||||
|
|
||||||
|
# Do all of the other checks
|
||||||
|
super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
|
||||||
|
|
||||||
|
if 'exp' in payload and options.get('exp_max_s') is not None:
|
||||||
|
# Validate that the expiration was not more than exp_max_s seconds after the issue time
|
||||||
|
# or in the absense of an issue time, more than exp_max_s in the future from now
|
||||||
|
|
||||||
|
# This will work because the parent method already checked the type of exp
|
||||||
|
expiration = datetime.utcfromtimestamp(int(payload['exp']))
|
||||||
|
max_signed_s = options.get('exp_max_s')
|
||||||
|
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
if 'iat' in payload:
|
||||||
|
start_time = datetime.utcfromtimestamp(int(payload['iat']))
|
||||||
|
|
||||||
|
if expiration > start_time + timedelta(seconds=max_signed_s):
|
||||||
|
raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s,
|
||||||
|
start_time)
|
||||||
|
|
||||||
|
|
||||||
|
def exp_max_s_option(max_exp_s):
|
||||||
|
return {
|
||||||
|
'exp_max_s': max_exp_s,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
decode = StrictJWT().decode
|
decode = StrictJWT().decode
|
||||||
|
|
Reference in a new issue