"""Tests for validating registry (v2) JWT bearer tokens via identity_from_bearer_token."""

import unittest
import time

import jwt

from app import app, instance_keys
from data import model
from data.database import ServiceKeyApprovalType
from initdb import setup_database_for_testing, finished_database_for_testing
from endpoints.v2.v2auth import TOKEN_VALIDITY_LIFETIME_S
from auth.registry_jwt_auth import identity_from_bearer_token, InvalidJWTException
from util.morecollections import AttrDict
from util.security.registry_jwt import (ANONYMOUS_SUB, build_context_and_subject,
                                        decode_bearer_token, generate_bearer_token)


TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
TEST_USER = AttrDict({'username': 'joeuser'})

# Used to build an `exp` claim far enough in the future (MAX_SIGNED_S * 2) that the
# parser is expected to reject it; presumably mirrors a server-side signing limit.
MAX_SIGNED_S = 3660


class TestRegistryV2Auth(unittest.TestCase):
  """Exercises acceptance and rejection of registry JWT bearer tokens."""

  def setUp(self):
    setup_database_for_testing(self)

  def tearDown(self):
    finished_database_for_testing(self)

  def _generate_token_data(self, access=None, audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
                           exp=None, nbf=None, iss=None):
    """Build a JWT claims dict; any claim left as None gets a valid default.

    `access` defaults to a fresh empty list per call (avoids a shared mutable default).
    """
    _, subject = build_context_and_subject(user, None, None)

    # Sample the clock once so nbf/iat/exp are mutually consistent.
    now = time.time()
    return {
      'iss': iss or instance_keys.service_name,
      'aud': audience,
      'nbf': nbf if nbf is not None else int(now),
      'iat': iat if iat is not None else int(now),
      'exp': exp if exp is not None else int(now + TOKEN_VALIDITY_LIFETIME_S),
      'sub': subject,
      'access': access if access is not None else [],
    }

  def _generate_token(self, token_data, key_id=None, private_key=None, skip_header=False,
                      alg=None):
    """Encode `token_data` as a JWT and return it as a `Bearer <token>` header value.

    Defaults to the instance's local key id/private key and RS256. With alg="none"
    the token is produced unsigned; with skip_header=True the `kid` header is omitted.
    """
    key_id = key_id or instance_keys.local_key_id
    private_key = private_key or instance_keys.local_private_key

    if alg == "none":
      # Unsigned tokens must not carry a key.
      private_key = None

    token_headers = {} if skip_header else {'kid': key_id}
    encoded = jwt.encode(token_data, private_key, alg or 'RS256', headers=token_headers)
    return 'Bearer {0}'.format(encoded)

  def _parse_token(self, token):
    """Return the identity (first element) produced by identity_from_bearer_token."""
    return identity_from_bearer_token(token)[0]

  def test_accepted_token(self):
    token = self._generate_token(self._generate_token_data())
    identity = self._parse_token(token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    anon_token = self._generate_token(self._generate_token_data(user=None))
    anon_identity = self._parse_token(anon_token)
    self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
    # Check the anonymous identity itself (previously re-checked `identity`).
    self.assertEqual(0, len(anon_identity.provides))

  def test_token_with_access(self):
    access = [
      {
        'type': 'repository',
        'name': 'somens/somerepo',
        'actions': ['pull', 'push'],
      }
    ]

    token = self._generate_token(self._generate_token_data(access=access))
    identity = self._parse_token(token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(1, len(identity.provides))

  def test_malformed_access(self):
    access = [
      {
        'toipe': 'repository',
        'namesies': 'somens/somerepo',
        'akshuns': ['pull', 'push'],
      }
    ]

    token = self._generate_token(self._generate_token_data(access=access))
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

  def test_audience(self):
    token_data = self._generate_token_data(audience='someotherapp')
    token = self._generate_token(token_data)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('aud')
    no_aud = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_aud)

  def test_nbf(self):
    future = int(time.time()) + 60
    token_data = self._generate_token_data(nbf=future)
    token = self._generate_token(token_data)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('nbf')
    no_nbf_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_nbf_token)

  def test_iat(self):
    future = int(time.time()) + 60
    token_data = self._generate_token_data(iat=future)
    token = self._generate_token(token_data)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('iat')
    no_iat_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_iat_token)

  def test_exp(self):
    too_far = int(time.time()) + MAX_SIGNED_S * 2
    token_data = self._generate_token_data(exp=too_far)
    token = self._generate_token(token_data)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    past = int(time.time()) - 60
    token_data['exp'] = past
    expired_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(expired_token)

    token_data.pop('exp')
    no_exp_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_exp_token)

  def test_no_sub(self):
    token_data = self._generate_token_data()
    token_data.pop('sub')
    token = self._generate_token(token_data)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

  def test_iss(self):
    token_data = self._generate_token_data(iss='badissuer')
    token = self._generate_token(token_data)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('iss')
    no_iss_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_iss_token)

  def test_missing_header(self):
    token_data = self._generate_token_data()
    missing_header_token = self._generate_token(token_data, skip_header=True)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(missing_header_token)

  def test_invalid_key(self):
    token_data = self._generate_token_data()
    invalid_key_token = self._generate_token(token_data, key_id='someunknownkey')

    with self.assertRaises(InvalidJWTException):
      self._parse_token(invalid_key_token)

  def test_expired_key(self):
    token_data = self._generate_token_data()
    # NOTE(review): assumes the test database fixtures contain an expired key 'kid7'.
    expired_key_token = self._generate_token(token_data, key_id='kid7')

    with self.assertRaises(InvalidJWTException):
      self._parse_token(expired_key_token)

  def test_mixing_keys(self):
    token_data = self._generate_token_data()

    # Create a new key for testing.
    p, key = model.service_keys.generate_service_key(instance_keys.service_name, None,
                                                     kid='newkey', name='newkey', metadata={})
    private_key = p.exportKey('PEM')

    # Test first with the new valid, but unapproved key.
    unapproved_key_token = self._generate_token(token_data, key_id='newkey',
                                                private_key=private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(unapproved_key_token)

    # Approve the key and try again.
    admin_user = model.user.get_user('devtable')
    model.service_keys.approve_service_key(key.kid, admin_user,
                                           ServiceKeyApprovalType.SUPERUSER)

    valid_token = self._generate_token(token_data, key_id='newkey', private_key=private_key)

    identity = self._parse_token(valid_token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    # Try using a different private key with the existing key ID.
    bad_private_token = self._generate_token(token_data, key_id='newkey',
                                             private_key=instance_keys.local_private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(bad_private_token)

    # Try using a different key ID with the existing private key.
    kid_mismatch_token = self._generate_token(token_data, key_id=instance_keys.local_key_id,
                                              private_key=private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(kid_mismatch_token)

    # Delete the new key.
    key.delete_instance(recursive=True)

    # Ensure it still works (via the cache.)
    deleted_key_token = self._generate_token(token_data, key_id='newkey',
                                             private_key=private_key)
    identity = self._parse_token(deleted_key_token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    # Break the cache.
    instance_keys.clear_cache()

    # Ensure the key no longer works.
    with self.assertRaises(InvalidJWTException):
      self._parse_token(deleted_key_token)

  def test_bad_token(self):
    with self.assertRaises(InvalidJWTException):
      self._parse_token("some random token here")

  def test_bad_bearer_token(self):
    with self.assertRaises(InvalidJWTException):
      self._parse_token("Bearer: sometokenhere")

  def test_bad_bearer_newline_token(self):
    with self.assertRaises(InvalidJWTException):
      self._parse_token("\nBearer: dGVzdA")

  def test_ensure_no_none(self):
    token_data = self._generate_token_data()
    none_token = self._generate_token(token_data, alg='none', private_key=None)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(none_token)


if __name__ == '__main__':
  unittest.main()