301 lines
		
	
	
	
		
			9.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			301 lines
		
	
	
	
		
			9.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import unittest
 | |
| import time
 | |
| import jwt
 | |
| 
 | |
| from app import app, instance_keys
 | |
| from data import model
 | |
| from data.database import ServiceKeyApprovalType
 | |
| from initdb import setup_database_for_testing, finished_database_for_testing
 | |
| from endpoints.v2.v2auth import TOKEN_VALIDITY_LIFETIME_S
 | |
| from auth.registry_jwt_auth import identity_from_bearer_token, InvalidJWTException
 | |
| from util.morecollections import AttrDict
 | |
| from util.security.registry_jwt import (ANONYMOUS_SUB, build_context_and_subject,
 | |
|                                         generate_bearer_token)
 | |
| 
 | |
| 
 | |
# Audience claim tokens are minted with by default; test_audience shows a token
# carrying any other audience is rejected.
TEST_AUDIENCE = app.config['SERVER_HOSTNAME']

# Subject user for generated tokens; parsed identities are compared against its
# `username`.
TEST_USER = AttrDict(dict(username='joeuser'))

# NOTE(review): appears to be the longest token lifetime (seconds) the server will
# accept; test_exp treats an expiry of twice this as "too far". 61 * 60 == 3660.
MAX_SIGNED_S = 61 * 60
 | |
| 
 | |
| 
 | |
class TestRegistryV2Auth(unittest.TestCase):
  """Tests for parsing registry v2 bearer JWTs via identity_from_bearer_token."""

  def setUp(self):
    setup_database_for_testing(self)

  def tearDown(self):
    finished_database_for_testing(self)

  def _generate_token_data(self, access=None, context=None, audience=TEST_AUDIENCE,
                           user=TEST_USER, iat=None, exp=None, nbf=None, iss=None):
    """Build a JWT claims dict, filling any unspecified claim with a valid default.

    `access` defaults to a fresh empty list on every call (a shared mutable default
    would leak mutations between tests). When not given, `nbf` and `iat` are the
    current time and `exp` is the current time plus TOKEN_VALIDITY_LIFETIME_S; all
    three come from a single clock reading so they are mutually consistent. `iss`
    falls back to the instance service name.
    """
    if access is None:
      access = []

    _, subject = build_context_and_subject(user=user)
    now = time.time()
    return {
      'iss': iss or instance_keys.service_name,
      'aud': audience,
      'nbf': nbf if nbf is not None else int(now),
      'iat': iat if iat is not None else int(now),
      'exp': exp if exp is not None else int(now + TOKEN_VALIDITY_LIFETIME_S),
      'sub': subject,
      'access': access,
      'context': context,
    }

  def _generate_token(self, token_data, key_id=None, private_key=None, skip_header=False,
                      alg=None):
    """Encode `token_data` as a JWT and return it prefixed with 'Bearer '.

    `key_id` / `private_key` fall back to the local instance key. `alg` defaults to
    RS256; with alg == 'none' the token is encoded without a key. `skip_header`
    omits the 'kid' header entirely.
    """
    key_id = key_id or instance_keys.local_key_id
    private_key = private_key or instance_keys.local_private_key

    if alg == 'none':
      # An unsigned token is encoded with no key at all.
      private_key = None

    token_headers = {} if skip_header else {'kid': key_id}
    encoded = jwt.encode(token_data, private_key, alg or 'RS256', headers=token_headers)
    return 'Bearer {0}'.format(encoded)

  def _parse_token(self, token):
    """Return only the identity (first element) from identity_from_bearer_token."""
    return identity_from_bearer_token(token)[0]

  def test_accepted_token(self):
    token = self._generate_token(self._generate_token_data())
    identity = self._parse_token(token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    anon_token = self._generate_token(self._generate_token_data(user=None))
    anon_identity = self._parse_token(anon_token)
    self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
    # Check the anonymous identity itself (previously re-checked `identity`).
    self.assertEqual(0, len(anon_identity.provides))

  def test_token_with_access(self):
    valid_access = [
      [
        {
          'type': 'repository',
          'name': 'somens/somerepo',
          'actions': ['pull', 'push'],
        }
      ],
      [
        {
          'type': 'repository',
          'name': 'somens/somerepo',
          'actions': ['pull', '*'],
        }
      ],
      [
        {
          'type': 'repository',
          'name': 'somens/somerepo',
          'actions': ['*', 'push'],
        }
      ],
      [
        {
          'type': 'repository',
          'name': 'somens/somerepo',
          'actions': ['*'],
        }
      ],
      [
        {
          'type': 'repository',
          'name': 'somens/somerepo',
          'actions': ['pull', '*', 'push'],
        }
      ],
    ]
    for access in valid_access:
      token = self._generate_token(self._generate_token_data(access=access))
      identity = self._parse_token(token)
      self.assertEqual(identity.id, TEST_USER.username)
      self.assertEqual(1, len(identity.provides))

      # NOTE(review): assumes the role name is the 4th element of the provided need.
      role = list(identity.provides)[0][3]
      actions = access[0]['actions']
      # The strongest requested action determines the expected role.
      if '*' in actions:
        self.assertEqual(role, 'admin')
      elif 'push' in actions:
        self.assertEqual(role, 'write')
      elif 'pull' in actions:
        self.assertEqual(role, 'read')

  def test_malformed_access(self):
    # Access entries with unrecognized keys must cause the token to be rejected.
    access = [
      {
        'toipe': 'repository',
        'namesies': 'somens/somerepo',
        'akshuns': ['pull', 'push', '*'],
      }
    ]
    token = self._generate_token(self._generate_token_data(access=access))
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

  def test_audience(self):
    token_data = self._generate_token_data(audience='someotherapp')
    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('aud')
    no_aud = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_aud)

  def test_nbf(self):
    future = int(time.time()) + 60
    token_data = self._generate_token_data(nbf=future)

    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('nbf')
    no_nbf_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_nbf_token)

  def test_iat(self):
    future = int(time.time()) + 60
    token_data = self._generate_token_data(iat=future)

    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('iat')
    no_iat_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_iat_token)

  def test_exp(self):
    too_far = int(time.time()) + MAX_SIGNED_S * 2
    token_data = self._generate_token_data(exp=too_far)

    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    past = int(time.time()) - 60
    token_data['exp'] = past
    expired_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(expired_token)

    token_data.pop('exp')
    no_exp_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_exp_token)

  def test_no_sub(self):
    token_data = self._generate_token_data()
    token_data.pop('sub')
    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

  def test_iss(self):
    token_data = self._generate_token_data(iss='badissuer')
    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('iss')
    no_iss_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_iss_token)

  def test_missing_header(self):
    token_data = self._generate_token_data()
    missing_header_token = self._generate_token(token_data, skip_header=True)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(missing_header_token)

  def test_invalid_key(self):
    token_data = self._generate_token_data()
    invalid_key_token = self._generate_token(token_data, key_id='someunknownkey')
    with self.assertRaises(InvalidJWTException):
      self._parse_token(invalid_key_token)

  def test_expired_key(self):
    token_data = self._generate_token_data()
    # NOTE(review): presumably 'kid7' is an expired key seeded by the test database.
    expired_key_token = self._generate_token(token_data, key_id='kid7')
    with self.assertRaises(InvalidJWTException):
      self._parse_token(expired_key_token)

  def test_mixing_keys(self):
    token_data = self._generate_token_data()

    # Create a new key for testing.
    p, key = model.service_keys.generate_service_key(instance_keys.service_name, None,
                                                     kid='newkey', name='newkey',
                                                     metadata={})

    private_key = p.exportKey('PEM')

    # Test first with the new valid, but unapproved key.
    unapproved_key_token = self._generate_token(token_data, key_id='newkey',
                                                private_key=private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(unapproved_key_token)

    # Approve the key and try again.
    admin_user = model.user.get_user('devtable')
    model.service_keys.approve_service_key(key.kid, admin_user,
                                           ServiceKeyApprovalType.SUPERUSER)

    valid_token = self._generate_token(token_data, key_id='newkey', private_key=private_key)

    identity = self._parse_token(valid_token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    # Try using a different private key with the existing key ID.
    bad_private_token = self._generate_token(token_data, key_id='newkey',
                                             private_key=instance_keys.local_private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(bad_private_token)

    # Try using a different key ID with the existing private key.
    kid_mismatch_token = self._generate_token(token_data, key_id=instance_keys.local_key_id,
                                              private_key=private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(kid_mismatch_token)

    # Delete the new key.
    key.delete_instance(recursive=True)

    # Ensure it still works (via the cache.)
    deleted_key_token = self._generate_token(token_data, key_id='newkey',
                                             private_key=private_key)
    identity = self._parse_token(deleted_key_token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    # Break the cache.
    instance_keys.clear_cache()

    # Ensure the key no longer works.
    with self.assertRaises(InvalidJWTException):
      self._parse_token(deleted_key_token)

  def test_bad_token(self):
    with self.assertRaises(InvalidJWTException):
      self._parse_token("some random token here")

  def test_bad_bearer_token(self):
    with self.assertRaises(InvalidJWTException):
      self._parse_token("Bearer: sometokenhere")

  def test_bad_bearer_newline_token(self):
    with self.assertRaises(InvalidJWTException):
      self._parse_token("\nBearer: dGVzdA")

  def test_ensure_no_none(self):
    # A token using the 'none' algorithm (no signature) must never be accepted.
    token_data = self._generate_token_data()
    none_token = self._generate_token(token_data, alg='none', private_key=None)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(none_token)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
  # When executed directly as a script, hand control to unittest's runner,
  # which collects the TestCase classes defined in this module.
  unittest.main()
 | |
| 
 |