From 82efc746b3ccf5bf133b9dd993a5e46016ebbc0f Mon Sep 17 00:00:00 2001 From: Jake Moshenko Date: Fri, 4 Sep 2015 11:29:22 -0400 Subject: [PATCH] Make our JWT checking more strict. --- auth/jwt_auth.py | 15 ++++++--------- data/users/externaljwt.py | 10 ++++++---- requirements.txt | 2 +- util/security/strictjwt.py | 21 +++++++++++++++++++++ 4 files changed, 34 insertions(+), 14 deletions(-) create mode 100644 util/security/strictjwt.py diff --git a/auth/jwt_auth.py b/auth/jwt_auth.py index cd1a6ca31..9a4aa1bbe 100644 --- a/auth/jwt_auth.py +++ b/auth/jwt_auth.py @@ -1,5 +1,4 @@ import logging -import jwt import re from datetime import datetime, timedelta @@ -11,10 +10,11 @@ from cryptography.hazmat.backends import default_backend from cachetools import lru_cache from app import app -from auth_context import set_grant_user_context -from permissions import repository_read_grant, repository_write_grant +from .auth_context import set_grant_user_context +from .permissions import repository_read_grant, repository_write_grant from util.names import parse_namespace_repository from util.http import abort +from util.security import strictjwt logger = logging.getLogger(__name__) @@ -44,17 +44,14 @@ def identity_from_bearer_token(bearer_token, max_signed_s, public_key): # Load the JWT returned. try: - payload = jwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay', - issuer='token-issuer') - except jwt.InvalidTokenError: + payload = strictjwt.decode(encoded, public_key, algorithms=['RS256'], audience='quay', + issuer='token-issuer') + except strictjwt.InvalidTokenError: raise InvalidJWTException('Invalid token') if not 'sub' in payload: raise InvalidJWTException('Missing sub field in JWT') - if not 'exp' in payload: - raise InvalidJWTException('Missing exp field in JWT') - # Verify that the expiration is no more than 300 seconds in the future. if datetime.fromtimestamp(payload['exp']) > datetime.utcnow() + timedelta(seconds=max_signed_s): raise InvalidJWTException('Token was signed for more than %s seconds' % max_signed_s) diff --git a/data/users/externaljwt.py b/data/users/externaljwt.py index 241cfa947..ac29f22a1 100644 --- a/data/users/externaljwt.py +++ b/data/users/externaljwt.py @@ -1,13 +1,15 @@ import logging import json import os -import jwt from datetime import datetime, timedelta from data.users.federated import FederatedUsers, VerifiedCredentials +from util.security import strictjwt + logger = logging.getLogger(__name__) + class ExternalJWTAuthN(FederatedUsers): """ Delegates authentication to a REST endpoint that returns JWTs. """ PUBLIC_KEY_FILENAME = 'jwt-authn.cert' @@ -45,9 +47,9 @@ class ExternalJWTAuthN(FederatedUsers): # Load the JWT returned. encoded = result_data.get('token', '') try: - payload = jwt.decode(encoded, self.public_key, algorithms=['RS256'], - audience='quay.io/jwtauthn', issuer=self.issuer) - except jwt.InvalidTokenError: + payload = strictjwt.decode(encoded, self.public_key, algorithms=['RS256'], + audience='quay.io/jwtauthn', issuer=self.issuer) + except strictjwt.InvalidTokenError: logger.exception('Exception when decoding returned JWT') return (None, 'Invalid username or password') diff --git a/requirements.txt b/requirements.txt index 40828e1f5..ba6dcbb2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -57,7 +57,7 @@ pyasn1==0.1.8 pycparser==2.14 pycrypto==2.6.1 pygpgme==0.3 -PyJWT==1.3.0 +PyJWT==1.4.0 PyMySQL==0.6.6 pyOpenSSL==0.15.1 PyPDF2==1.24 diff --git a/util/security/strictjwt.py b/util/security/strictjwt.py new file mode 100644 index 000000000..35f94444c --- /dev/null +++ b/util/security/strictjwt.py @@ -0,0 +1,21 @@ +from jwt import PyJWT +from jwt.exceptions import ( + InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError, + ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError +) + + +class StrictJWT(PyJWT): + @staticmethod + def _get_default_options(): + # Weird syntax to call super on a staticmethod + defaults = super(StrictJWT, StrictJWT)._get_default_options() + defaults.update({ + 'require_exp': True, + 'require_iat': True, + 'require_nbf': True, + }) + return defaults + + +decode = StrictJWT().decode