import unittest

from datetime import datetime, timedelta

import jwt

from Crypto.PublicKey import RSA
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key

from util.config.provider.license import (decode_license, LICENSE_PRODUCT_NAME,
                                          LicenseValidationError)


class TestLicense(unittest.TestCase):
    """Tests for decoding licenses and checking expiration and entitlements.

    Each test builds a license payload, signs it as a JWT with the test
    private key, decodes it through ``decode_license`` and then inspects the
    decoded object.
    """

    def keys(self):
        """Load the test key pair.

        Returns:
            A ``(public_key, private_key)`` tuple. ``private_key`` is the raw
            contents of ``test/data/test.pem``; ``public_key`` is the matching
            public key loaded from its DER export.
        """
        with open('test/data/test.pem') as f:
            private_key = f.read()

        public_der = RSA.importKey(private_key).publickey().exportKey('DER')
        public_key = load_der_public_key(public_der, backend=default_backend())
        return (public_key, private_key)

    def create_license(self, license_data):
        """Sign ``license_data`` as an RS256 JWT and decode it into a license.

        Args:
            license_data: dict holding the license payload.

        Returns:
            Whatever ``decode_license`` returns for the signed payload.
        """
        (public_key, private_key) = self.keys()

        # Encode the license with the JWT key.
        encoded = jwt.encode(license_data, private_key, algorithm='RS256')

        # Decode it into a license object.
        return decode_license(encoded, public_key_instance=public_key)

    def get_license(self, expiration_delta=None, **kwargs):
        """Build and decode a license expiring ``expiration_delta`` from now.

        Args:
            expiration_delta: ``timedelta`` added to the current time to form
                ``expirationDate``. ``None`` is treated as a zero offset (the
                license expires at creation time); previously ``None`` raised
                ``TypeError``.
            **kwargs: Optional subscription settings. If at least one keyword
                is given, a single subscription for ``LICENSE_PRODUCT_NAME``
                is attached to the license. Recognized keys:
                ``trial_only`` (default ``False``), ``in_trial`` (default
                ``False``), ``entitlements`` (default ``[]``), ``trial_end``
                and ``service_end`` (``timedelta`` offsets from now) and
                ``duration`` (stored as ``durationPeriod``).

        Returns:
            The decoded license object.
        """
        if expiration_delta is None:
            expiration_delta = timedelta()

        license_data = {
            'expirationDate': str(datetime.now() + expiration_delta),
        }

        # A subscription is only added when the caller asked for one; a call
        # without keywords yields a license with no subscriptions at all.
        if kwargs:
            sub = {
                'productName': LICENSE_PRODUCT_NAME,
                'trialOnly': kwargs.get('trial_only', False),
                'inTrial': kwargs.get('in_trial', False),
                'entitlements': kwargs.get('entitlements', []),
            }

            if 'trial_end' in kwargs:
                sub['trialEnd'] = str(datetime.now() + kwargs['trial_end'])

            if 'service_end' in kwargs:
                sub['serviceEnd'] = str(datetime.now() + kwargs['service_end'])

            if 'duration' in kwargs:
                sub['durationPeriod'] = kwargs['duration']

            license_data['subscriptions'] = [sub]

        return self.create_license(license_data)

    def test_license_itself_expired(self):
        # License is expired.
        lic = self.get_license(timedelta(days=-30))
        self.assertTrue(lic.is_expired)

    def test_no_qe_subscription(self):
        # License is not expired, but there is no QE sub, so not valid.
        lic = self.get_license(timedelta(days=30))
        self.assertTrue(lic.is_expired)

    def test_trial_withingrace(self):
        # Trial ended one day ago.
        lic = self.get_license(timedelta(days=30), trial_only=True,
                               trial_end=timedelta(days=-1))
        self.assertFalse(lic.is_expired)

    def test_trial_outsidegrace(self):
        # Trial ended ten days ago.
        lic = self.get_license(timedelta(days=30), trial_only=True,
                               trial_end=timedelta(days=-10))
        self.assertTrue(lic.is_expired)

    def test_trial_intrial_withingrace(self):
        # In-trial subscription whose service ended one day ago.
        lic = self.get_license(timedelta(days=30), in_trial=True,
                               service_end=timedelta(days=-1))
        self.assertFalse(lic.is_expired)

    def test_trial_intrial_outsidegrace(self):
        # In-trial subscription whose service ended ten days ago.
        lic = self.get_license(timedelta(days=30), in_trial=True,
                               service_end=timedelta(days=-10))
        self.assertTrue(lic.is_expired)

    def test_monthly_license_valid(self):
        # Monthly subscription with service still active.
        lic = self.get_license(timedelta(days=30),
                               service_end=timedelta(days=10),
                               duration='monthly')
        self.assertFalse(lic.is_expired)

    def test_monthly_license_withingrace(self):
        # Monthly subscription whose service ended ten days ago.
        lic = self.get_license(timedelta(days=30),
                               service_end=timedelta(days=-10),
                               duration='monthly')
        self.assertFalse(lic.is_expired)

    def test_monthly_license_outsidegrace(self):
        # Monthly subscription whose service ended forty days ago.
        lic = self.get_license(timedelta(days=30),
                               service_end=timedelta(days=-40),
                               duration='monthly')
        self.assertTrue(lic.is_expired)

    def test_yearly_license_withingrace(self):
        # Yearly subscription whose service ended forty days ago.
        lic = self.get_license(timedelta(days=30),
                               service_end=timedelta(days=-40),
                               duration='years')
        self.assertFalse(lic.is_expired)

    def test_yearly_license_outsidegrace(self):
        # Yearly subscription whose service ended one hundred days ago.
        lic = self.get_license(timedelta(days=30),
                               service_end=timedelta(days=-100),
                               duration='years')
        self.assertTrue(lic.is_expired)

    def test_valid_license(self):
        # Yearly subscription with service still active.
        lic = self.get_license(timedelta(days=300),
                               service_end=timedelta(days=40),
                               duration='years')
        self.assertFalse(lic.is_expired)

    def test_validate_basic_license(self):
        # Passes as long as validate() does not raise.
        decoded = self.get_license(timedelta(days=30), entitlements={})
        decoded.validate({'DISTRIBUTED_STORAGE_CONFIG': [{}]})

    def test_validate_storage_entitlement_valid(self):
        # One configured storage entry against a regions entitlement of two.
        decoded = self.get_license(timedelta(days=30), entitlements={
            'software.quay.regions': 2,
        })
        decoded.validate({'DISTRIBUTED_STORAGE_CONFIG': [{}]})

    def test_validate_storage_entitlement_invalid(self):
        # Two configured storage entries against a regions entitlement of one.
        decoded = self.get_license(timedelta(days=30), entitlements={
            'software.quay.regions': 1,
        })
        with self.assertRaises(LicenseValidationError):
            decoded.validate({'DISTRIBUTED_STORAGE_CONFIG': [{}, {}]})


if __name__ == '__main__':
    unittest.main()