Merge pull request #446 from coreos-inc/strictjwt

Make our JWT checking more strict.
This commit is contained in:
Jake Moshenko 2015-09-04 15:34:24 -04:00
commit 1635104280
4 changed files with 34 additions and 14 deletions

View file

@ -1,5 +1,4 @@
import logging import logging
import jwt
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -11,10 +10,11 @@ from cryptography.hazmat.backends import default_backend
from cachetools import lru_cache from cachetools import lru_cache
from app import app from app import app
from auth_context import set_grant_user_context from .auth_context import set_grant_user_context
from permissions import repository_read_grant, repository_write_grant from .permissions import repository_read_grant, repository_write_grant
from util.names import parse_namespace_repository from util.names import parse_namespace_repository
from util.http import abort from util.http import abort
from util.security import strictjwt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,17 +44,14 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key):
# Load the JWT returned. # Load the JWT returned.
try: try:
payload = jwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay', payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay',
issuer='token-issuer') issuer='token-issuer')
except jwt.InvalidTokenError: except strictjwt.InvalidTokenError:
raise InvalidJWTException('Invalid token') raise InvalidJWTException('Invalid token')
if not 'sub' in payload: if not 'sub' in payload:
raise InvalidJWTException('Missing sub field in JWT') raise InvalidJWTException('Missing sub field in JWT')
if not 'exp' in payload:
raise InvalidJWTException('Missing exp field in JWT')
# Verify that the expiration is no more than 300 seconds in the future. # Verify that the expiration is no more than 300 seconds in the future.
if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s): if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s):
raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s) raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s)

View file

@ -1,13 +1,15 @@
import logging import logging
import json import json
import os import os
import jwt
from datetime import datetime, timedelta from datetime import datetime, timedelta
from data.users.federated import FederatedUsers, VerifiedCredentials from data.users.federated import FederatedUsers, VerifiedCredentials
from util.security import strictjwt
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExternalJWTAuthN(FederatedUsers): class ExternalJWTAuthN(FederatedUsers):
""" Delegates authentication to a REST endpoint that returns JWTs. """ """ Delegates authentication to a REST endpoint that returns JWTs. """
PUBLIC_KEY_FILENAME = 'jwt-authn.cert' PUBLIC_KEY_FILENAME = 'jwt-authn.cert'
@ -45,9 +47,9 @@ class ExternalJWTAuthN(FederatedUsers):
# Load the JWT returned. # Load the JWT returned.
encoded = result_data.get('token', '') encoded = result_data.get('token', '')
try: try:
payload = jwt.decode(encoded, self.public_key, algorithms=['RS256'], payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'],
audience='quay.io/jwtauthn', issuer=self.issuer) audience='quay.io/jwtauthn', issuer=self.issuer)
except jwt.InvalidTokenError: except strictjwt.InvalidTokenError:
logger.exception('Exception when decoding returned JWT') logger.exception('Exception when decoding returned JWT')
return (None, 'Invalid username or password') return (None, 'Invalid username or password')

View file

@ -57,7 +57,7 @@ pyasn1==0.1.8
pycparser==2.14 pycparser==2.14
pycrypto==2.6.1 pycrypto==2.6.1
pygpgme==0.3 pygpgme==0.3
PyJWT==1.3.0 PyJWT==1.4.0
PyMySQL==0.6.6 PyMySQL==0.6.6
pyOpenSSL==0.15.1 pyOpenSSL==0.15.1
PyPDF2==1.24 PyPDF2==1.24

View file

@ -0,0 +1,21 @@
from jwt import PyJWT
from jwt.exceptions import (
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
)
class StrictJWT(PyJWT):
@staticmethod
def _get_default_options():
# Weird syntax to call super on a staticmethod
defaults = super(StrictJWT, StrictJWT)._get_default_options()
defaults.update({
'require_exp': True,
'require_iat': True,
'require_nbf': True,
})
return defaults
decode = StrictJWT().decode