"""Unit tests for parsing registry V2 JWT bearer tokens via identity_from_bearer_token."""

import unittest
import time

import jwt

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa

from app import app
from endpoints.v2.v2auth import TOKEN_VALIDITY_LIFETIME_S
from auth.registry_jwt_auth import identity_from_bearer_token, load_public_key, InvalidJWTException
from util.morecollections import AttrDict
from util.security.registry_jwt import (_load_certificate_bytes, _load_private_key, ANONYMOUS_SUB,
                                        build_context_and_subject)


TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
TEST_USER = AttrDict({'username': 'joeuser'})

# Passed to identity_from_bearer_token as its max-signed-seconds argument. test_exp relies on a
# token expiring twice this far in the future being rejected.
MAX_SIGNED_S = 3660


class TestRegistryV2Auth(unittest.TestCase):
    """Checks that identity_from_bearer_token accepts well-formed tokens and rejects bad claims."""

    def __init__(self, *args, **kwargs):
        super(TestRegistryV2Auth, self).__init__(*args, **kwargs)
        self.public_key = None

    def setUp(self):
        """Load the public key from the app's configured JWT certificate path."""
        certificate_file_path = app.config['JWT_AUTH_CERTIFICATE_PATH']
        self.public_key = load_public_key(certificate_file_path)

    def _generate_token_data(self, access=None, audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
                             exp=None, nbf=None, iss=app.config['JWT_AUTH_TOKEN_ISSUER']):
        """Build a JWT claims dict, defaulting every claim to a value the parser should accept.

        Args:
            access: list of access entries for the 'access' claim; defaults to a new empty list.
            audience: value of the 'aud' claim.
            user: user passed to build_context_and_subject to derive 'sub' (None for anonymous).
            iat, exp, nbf: explicit timestamps; None means "derive from the current time".
            iss: value of the 'iss' claim.

        Returns:
            A fresh dict of claims; callers may mutate it (e.g. pop a claim) freely.
        """
        if access is None:
            # Fresh list per call: a shared `[]` default would be aliased into every dict returned.
            access = []

        _, subject = build_context_and_subject(user, None, None)

        # Read the clock once so nbf/iat/exp are mutually consistent.
        now = time.time()
        return {
            'iss': iss,
            'aud': audience,
            'nbf': nbf if nbf is not None else int(now),
            'iat': iat if iat is not None else int(now),
            'exp': exp if exp is not None else int(now + TOKEN_VALIDITY_LIFETIME_S),
            'sub': subject,
            'access': access,
        }

    def _generate_token(self, token_data):
        """Sign `token_data` with the configured private key (RS256) and return a Bearer header.

        The configured certificate is embedded in the 'x5c' JWT header.
        """
        certificate = _load_certificate_bytes(app.config['JWT_AUTH_CERTIFICATE_PATH'])
        token_headers = {
            'x5c': [certificate],
        }

        private_key = _load_private_key(app.config['JWT_AUTH_PRIVATE_KEY_PATH'])
        encoded = jwt.encode(token_data, private_key, 'RS256', headers=token_headers)
        # PyJWT < 2.0 returns bytes. Decode it so the header never renders as "Bearer b'...'".
        if isinstance(encoded, bytes):
            encoded = encoded.decode('ascii')
        return 'Bearer {0}'.format(encoded)

    def _parse_token(self, token):
        """Parse `token` with the test public key and return the first element of the result."""
        return identity_from_bearer_token(token, MAX_SIGNED_S, self.public_key)[0]

    def _generate_public_key(self):
        """Return the public half of a freshly generated RSA key (unrelated to the app's key)."""
        key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=1024,
            backend=default_backend()
        )
        return key.public_key()

    def test_accepted_token(self):
        token = self._generate_token(self._generate_token_data())
        identity = self._parse_token(token)
        self.assertEqual(identity.id, TEST_USER.username)
        self.assertEqual(0, len(identity.provides))

        anon_token = self._generate_token(self._generate_token_data(user=None))
        anon_identity = self._parse_token(anon_token)
        self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
        # Check the anonymous identity itself (previously re-checked `identity` by mistake).
        self.assertEqual(0, len(anon_identity.provides))

    def test_token_with_access(self):
        access = [
            {
                'type': 'repository',
                'name': 'somens/somerepo',
                'actions': ['pull', 'push'],
            }
        ]
        token = self._generate_token(self._generate_token_data(access=access))
        identity = self._parse_token(token)
        self.assertEqual(identity.id, TEST_USER.username)
        self.assertEqual(1, len(identity.provides))

    def test_malformed_access(self):
        # Deliberately misspelled keys: the access entry must fail validation.
        access = [
            {
                'toipe': 'repository',
                'namesies': 'somens/somerepo',
                'akshuns': ['pull', 'push'],
            }
        ]
        token = self._generate_token(self._generate_token_data(access=access))

        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

    def test_bad_signature(self):
        token = self._generate_token(self._generate_token_data())
        other_public_key = self._generate_public_key()

        with self.assertRaises(InvalidJWTException):
            identity_from_bearer_token(token, MAX_SIGNED_S, other_public_key)

    def test_audience(self):
        token_data = self._generate_token_data(audience='someotherapp')

        token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

        token_data.pop('aud')
        no_aud = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(no_aud)

    def test_nbf(self):
        future = int(time.time()) + 60
        token_data = self._generate_token_data(nbf=future)

        token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

        token_data.pop('nbf')
        no_nbf_token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(no_nbf_token)

    def test_iat(self):
        future = int(time.time()) + 60
        token_data = self._generate_token_data(iat=future)

        token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

        token_data.pop('iat')
        no_iat_token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(no_iat_token)

    def test_exp(self):
        # Expiry far beyond MAX_SIGNED_S must be rejected.
        too_far = int(time.time()) + MAX_SIGNED_S * 2
        token_data = self._generate_token_data(exp=too_far)

        token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

        past = int(time.time()) - 60
        token_data['exp'] = past
        expired_token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(expired_token)

        token_data.pop('exp')
        no_exp_token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(no_exp_token)

    def test_no_sub(self):
        token_data = self._generate_token_data()
        token_data.pop('sub')
        token = self._generate_token(token_data)

        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

    def test_iss(self):
        token_data = self._generate_token_data(iss='badissuer')
        token = self._generate_token(token_data)

        with self.assertRaises(InvalidJWTException):
            self._parse_token(token)

        token_data.pop('iss')
        no_iss_token = self._generate_token(token_data)
        with self.assertRaises(InvalidJWTException):
            self._parse_token(no_iss_token)


if __name__ == '__main__':
    import logging
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()