import unittest

from datetime import datetime, timedelta

import jwt
import json

from Crypto.PublicKey import RSA
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key

from util.license import (decode_license, LicenseDecodeError, ExpirationType,
                          MONTHLY_GRACE_PERIOD, YEARLY_GRACE_PERIOD, TRIAL_GRACE_PERIOD,
                          QUAY_DEPLOYMENTS_ENTITLEMENT, QUAY_ENTITLEMENT)


def get_date(delta):
  """Return ``str()`` of the current naive local time shifted by ``delta``.

  The tests embed these strings as date fields in license payloads.
  """
  shifted = datetime.now() + delta
  return str(shifted)

class TestLicense(unittest.TestCase):
  """Tests for decoding and validating signed licenses via util.license."""

  def keys(self):
    """Return ``(public_key, private_key)`` for the test signing key.

    ``private_key`` is the PEM text read from ``test/data/test.pem``. ``public_key`` is the
    matching public key, exported as DER and loaded with ``cryptography`` so it can be passed
    to ``decode_license(public_key_instance=...)``.
    """
    with open('test/data/test.pem') as f:
      private_key = f.read()

    public_key = load_der_public_key(RSA.importKey(private_key).publickey().exportKey('DER'),
                                     backend=default_backend())
    return (public_key, private_key)

  def create_license(self, license_data, keys=None):
    """Sign ``license_data`` as an RS256 JWT and decode it back into a license object.

    ``license_data`` is JSON-encoded under the ``license`` claim. ``keys`` is an optional
    ``(public_key, private_key)`` tuple; it defaults to ``self.keys()``.
    """
    jwt_data = {
      'license': json.dumps(license_data),
    }

    (public_key, private_key) = keys or self.keys()

    # Encode the license with the JWT key.
    encoded = jwt.encode(jwt_data, private_key, algorithm='RS256')

    # Decode it into a license object.
    return decode_license(encoded, public_key_instance=public_key)

  def assertValid(self, license, config=None):
    """Assert that every result returned by ``license.validate(config)`` is met."""
    results = license.validate(config or {})
    is_met = all(r.is_met() for r in results)
    # On failure, report the unmet results to make the cause visible.
    self.assertTrue(is_met, [r for r in results if not r.is_met()])

  def assertNotValid(self, license, config=None, requirement=None, expired=None):
    """Assert that at least one result returned by ``license.validate(config)`` is unmet.

    If ``requirement`` is given, it must equal the requirement name of the FIRST unmet result.
    If ``expired`` is given, it must equal that result's entitlement expiration type.
    """
    results = license.validate(config or {})
    is_met = all(r.is_met() for r in results)
    self.assertFalse(is_met)

    # Non-empty here: the assertFalse above guarantees at least one unmet result.
    invalid_results = [r for r in results if not r.is_met()]
    if requirement is not None:
      self.assertEqual(invalid_results[0].requirement.name, requirement)

    if expired is not None:
      self.assertEqual(invalid_results[0].entitlement.expiration.expiration_type, expired)

  def assertEntitlement(self, entitlement, expected_name, expected_date):
    """Assert a validation result names ``expected_name`` and expires at ``expected_date``.

    The expiration date is compared via ``str()`` against the string produced by ``get_date``.
    """
    self.assertEqual(expected_name, entitlement.requirement.name)
    self.assertEqual(expected_date, str(entitlement.entitlement.expiration.expiration_date))

  def test_license_decodeerror_invalid(self):
    with self.assertRaises(LicenseDecodeError):
      decode_license('some random stuff')

  def test_license_decodeerror_badkey(self):
    (_, private_key) = self.keys()
    jwt_data = {
      'license': json.dumps({}),
    }

    encoded_stuff = jwt.encode(jwt_data, private_key, algorithm='RS256')
    with self.assertRaises(LicenseDecodeError):
      # Note that since we don't give a key here, the prod one will be used, and it should fail.
      decode_license(encoded_stuff)

  def test_missing_subscriptions(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
    })

    self.assertNotValid(license, requirement=QUAY_ENTITLEMENT)

  def test_empty_subscriptions(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {},
    })

    self.assertNotValid(license, requirement=QUAY_ENTITLEMENT)

  def test_missing_quay_entitlement(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_DEPLOYMENTS_ENTITLEMENT: 0,
          },
        },
      },
    })

    self.assertNotValid(license, requirement=QUAY_ENTITLEMENT)

  def test_valid_quay_entitlement(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertValid(license)

  def test_missing_expiration(self):
    license = self.create_license({
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertNotValid(license, expired=ExpirationType.license_wide)

  def test_expired_license(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=-10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertNotValid(license, expired=ExpirationType.license_wide)

  def test_expired_sub_implicit_monthly_withingrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(MONTHLY_GRACE_PERIOD * -1 + timedelta(days=1)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertValid(license)

  def test_expired_sub_monthly_withingrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(MONTHLY_GRACE_PERIOD * -1 + timedelta(days=1)),
          "durationPeriod": "monthly",
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertValid(license)

  def test_expired_sub_monthly_outsidegrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(MONTHLY_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "durationPeriod": "monthly",
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertNotValid(license, expired=ExpirationType.monthly)

  def test_expired_sub_yearly_withingrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(YEARLY_GRACE_PERIOD * -1 + timedelta(days=1)),
          "durationPeriod": "yearly",
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertValid(license)

  def test_expired_sub_yearly_outsidegrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(YEARLY_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "durationPeriod": "yearly",
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertNotValid(license, expired=ExpirationType.yearly)

  def test_expired_sub_intrial_withingrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=1)),
          "inTrial": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertValid(license)

  def test_expired_sub_intrial_outsidegrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "inTrial": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertNotValid(license, expired=ExpirationType.in_trial)

  def test_expired_sub_trialonly_withingrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "trialEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=1)),
          "trialOnly": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertValid(license)

  def test_expired_sub_trialonly_outsidegrace(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "trialEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "trialOnly": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    self.assertNotValid(license, expired=ExpirationType.trial_only)

  def test_valid_quay_entitlement_regions(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
      ],
    }

    self.assertValid(license, config=config)

  def test_invalid_quay_entitlement_regions(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertNotValid(license, config=config, requirement=QUAY_DEPLOYMENTS_ENTITLEMENT)

  def test_valid_regions_across_multiple_sub(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
        "anothersub": {
          "serviceEnd": get_date(timedelta(days=20)),
          "entitlements": {
            QUAY_DEPLOYMENTS_ENTITLEMENT: 5,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertValid(license, config=config)

  # Previously this test shared its name with the next test and was silently shadowed, so it
  # never ran. It is renamed to match what it asserts (an invalid license).
  def test_invalid_regions_across_multiple_sub_one_expired(self):
    # Setup a license with one sub having too few regions, and another having enough, but it is
    # expired.
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "serviceEnd": get_date(timedelta(days=10)),
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 1,
          },
        },
        "anothersub": {
          "trialEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "trialOnly": True,
          "entitlements": {
            QUAY_DEPLOYMENTS_ENTITLEMENT: 5,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertNotValid(license, config=config, requirement=QUAY_DEPLOYMENTS_ENTITLEMENT,
                        expired=ExpirationType.trial_only)

  def test_valid_regions_across_multiple_sub_one_expired(self):
    service_end = get_date(timedelta(days=20))
    expiration_date = get_date(timedelta(days=10))

    license = self.create_license({
      "expirationDate": expiration_date,
      "subscriptions": {
        "somesub": {
          "trialEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "trialOnly": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 3,
          },
        },
        "anothersub": {
          "serviceEnd": service_end,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 5,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertValid(license, config=config)

    entitlements = license.validate(config)
    self.assertEqual(2, len(entitlements))

    self.assertEntitlement(entitlements[0], QUAY_ENTITLEMENT, expiration_date)
    self.assertEntitlement(entitlements[1], QUAY_DEPLOYMENTS_ENTITLEMENT, expiration_date)

  def test_quay_is_under_expired_sub(self):
    license = self.create_license({
      "expirationDate": get_date(timedelta(days=10)),
      "subscriptions": {
        "somesub": {
          "trialEnd": get_date(TRIAL_GRACE_PERIOD * -1 + timedelta(days=-1)),
          "trialOnly": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 3,
          },
        },
        "anothersub": {
          "serviceEnd": get_date(timedelta(days=20)),
          "entitlements": {
            QUAY_DEPLOYMENTS_ENTITLEMENT: 5,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertNotValid(license, config=config, expired=ExpirationType.trial_only,
                        requirement=QUAY_ENTITLEMENT)

  def test_license_with_multiple_subscriptions(self):
    service_end = get_date(timedelta(days=20))
    expiration_date = get_date(timedelta(days=10))
    trial_end = get_date(timedelta(days=2))

    license = self.create_license({
      "expirationDate": expiration_date,
      "subscriptions": {
        "realsub": {
          "serviceEnd": service_end,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
          },
        },
        "trialsub": {
          "trialEnd": trial_end,
          "trialOnly": True,
          "inTrial": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 3,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertValid(license, config=config)

    entitlements = license.validate(config)
    self.assertEqual(2, len(entitlements))

    self.assertEntitlement(entitlements[0], QUAY_ENTITLEMENT, expiration_date)
    self.assertEntitlement(entitlements[1], QUAY_DEPLOYMENTS_ENTITLEMENT, trial_end)

  def test_license_with_multiple_subscriptions_one_expired(self):
    service_end = get_date(timedelta(days=20))
    expiration_date = get_date(timedelta(days=10))
    trial_end = get_date(timedelta(days=-2))

    license = self.create_license({
      "expirationDate": expiration_date,
      "subscriptions": {
        "realsub": {
          "serviceEnd": service_end,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 3,
          },
        },
        "trialsub": {
          "trialEnd": trial_end,
          "trialOnly": True,
          "inTrial": True,
          "entitlements": {
            QUAY_ENTITLEMENT: 1,
            QUAY_DEPLOYMENTS_ENTITLEMENT: 3,
          },
        },
      },
    })

    config = {
      'DISTRIBUTED_STORAGE_CONFIG': [
        {'name': 'first'},
        {'name': 'second'},
      ],
    }

    self.assertValid(license, config=config)

    entitlements = license.validate(config)
    self.assertEqual(2, len(entitlements))

    self.assertEntitlement(entitlements[0], QUAY_ENTITLEMENT, expiration_date)
    self.assertEntitlement(entitlements[1], QUAY_DEPLOYMENTS_ENTITLEMENT, expiration_date)

if __name__ == '__main__':
  # Running this file directly executes every TestCase defined in this module.
  # verbosity=1 is unittest's default; a -v flag on the command line still raises it.
  unittest.main(verbosity=1)