import unittest
import time
import jwt

from app import app, instance_keys
from data import model
from data.database import ServiceKeyApprovalType
from initdb import setup_database_for_testing, finished_database_for_testing
from endpoints.v2.v2auth import TOKEN_VALIDITY_LIFETIME_S
from auth.registry_jwt_auth import identity_from_bearer_token, InvalidJWTException
from util.morecollections import AttrDict
from util.security.registry_jwt import (ANONYMOUS_SUB, build_context_and_subject,
                                        generate_bearer_token)


# Default audience ('aud') claim for generated test tokens, read from the app config.
TEST_AUDIENCE = app.config['SERVER_HOSTNAME']
# Default user passed to build_context_and_subject when generating token claims.
TEST_USER = AttrDict({'username': 'joeuser'})
# Seconds. test_exp uses twice this value to build an 'exp' claim that must be rejected.
# NOTE(review): presumably mirrors the maximum signed lifetime enforced by the JWT
# validation code -- confirm against auth.registry_jwt_auth.
MAX_SIGNED_S = 3660


class TestRegistryV2Auth(unittest.TestCase):
  """Tests for bearer-token (JWT) parsing via ``identity_from_bearer_token``.

  Each test builds a claim set, signs it with ``jwt.encode`` and checks that
  parsing the resulting ``Bearer`` header either yields the expected identity
  or raises ``InvalidJWTException``.
  """

  def setUp(self):
    """Prepare the test database before each test."""
    setup_database_for_testing(self)

  def tearDown(self):
    """Release the test database after each test."""
    finished_database_for_testing(self)

  def _generate_token_data(self, access=None, audience=TEST_AUDIENCE, user=TEST_USER, iat=None,
                           exp=None, nbf=None, iss=None):
    """Build the claims dict for a test token.

    Args:
      access: list of access entries for the 'access' claim; defaults to a
        fresh empty list.
      audience: value of the 'aud' claim.
      user: user handed to ``build_context_and_subject`` to derive 'sub';
        ``None`` is used by the tests for the anonymous case.
      iat: 'iat' claim; defaults to the current time.
      exp: 'exp' claim; defaults to now + ``TOKEN_VALIDITY_LIFETIME_S``.
      nbf: 'nbf' claim; defaults to the current time.
      iss: 'iss' claim; defaults to ``instance_keys.service_name``.

    Returns:
      A new dict of claims that the caller may freely mutate.
    """
    _, subject = build_context_and_subject(user, None, None)

    # Sample the clock once so the defaulted nbf/iat/exp claims are mutually
    # consistent and cannot straddle a second boundary.
    now = time.time()
    return {
      'iss': iss or instance_keys.service_name,
      'aud': audience,
      'nbf': nbf if nbf is not None else int(now),
      'iat': iat if iat is not None else int(now),
      'exp': exp if exp is not None else int(now + TOKEN_VALIDITY_LIFETIME_S),
      'sub': subject,
      # A fresh list per call: a shared mutable default would leak between tokens.
      'access': [] if access is None else access,
    }

  def _generate_token(self, token_data, key_id=None, private_key=None, skip_header=False, alg=None):
    """Sign ``token_data`` and return it as a ``Bearer`` authorization value.

    Args:
      token_data: claims dict to encode.
      key_id: value of the 'kid' header; defaults to ``instance_keys.local_key_id``.
      private_key: signing key; defaults to ``instance_keys.local_private_key``.
      skip_header: when true, no extra headers (and thus no 'kid') are passed.
      alg: signing algorithm; defaults to 'RS256'. With "none" the key is
        forced to ``None``.

    Returns:
      The string 'Bearer <encoded token>'.
    """
    key_id = key_id or instance_keys.local_key_id
    private_key = private_key or instance_keys.local_private_key

    # The "none" algorithm must be encoded without a key, even though the
    # fallback above just filled one in.
    if alg == "none":
      private_key = None

    token_headers = {} if skip_header else {'kid': key_id}

    token_data = jwt.encode(token_data, private_key, alg or 'RS256', headers=token_headers)
    return 'Bearer {0}'.format(token_data)

  def _parse_token(self, token):
    """Return the first element of ``identity_from_bearer_token(token)``."""
    return identity_from_bearer_token(token)[0]

  def test_accepted_token(self):
    """Well-formed tokens for a named user and for the anonymous user are accepted."""
    token = self._generate_token(self._generate_token_data())
    identity = self._parse_token(token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    anon_token = self._generate_token(self._generate_token_data(user=None))
    anon_identity = self._parse_token(anon_token)
    self.assertEqual(anon_identity.id, ANONYMOUS_SUB)
    # Check the anonymous identity itself, not the user identity parsed above.
    self.assertEqual(0, len(anon_identity.provides))

  def test_token_with_access(self):
    """A token with one repository access entry yields one provided need."""
    access = [
      {
        'type': 'repository',
        'name': 'somens/somerepo',
        'actions': ['pull', 'push'],
      }
    ]
    token = self._generate_token(self._generate_token_data(access=access))
    identity = self._parse_token(token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(1, len(identity.provides))

  def test_malformed_access(self):
    """An access entry with unrecognized keys is rejected."""
    access = [
      {
        'toipe': 'repository',
        'namesies': 'somens/somerepo',
        'akshuns': ['pull', 'push'],
      }
    ]
    token = self._generate_token(self._generate_token_data(access=access))
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

  def test_audience(self):
    """Tokens with a foreign or a missing 'aud' claim are rejected."""
    token_data = self._generate_token_data(audience='someotherapp')
    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('aud')
    no_aud = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_aud)

  def test_nbf(self):
    """Tokens with a future or a missing 'nbf' claim are rejected."""
    future = int(time.time()) + 60
    token_data = self._generate_token_data(nbf=future)

    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('nbf')
    no_nbf_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_nbf_token)

  def test_iat(self):
    """Tokens with a future or a missing 'iat' claim are rejected."""
    future = int(time.time()) + 60
    token_data = self._generate_token_data(iat=future)

    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('iat')
    no_iat_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_iat_token)

  def test_exp(self):
    """Tokens expiring too far ahead, already expired, or lacking 'exp' are rejected."""
    too_far = int(time.time()) + MAX_SIGNED_S * 2
    token_data = self._generate_token_data(exp=too_far)

    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    past = int(time.time()) - 60
    token_data['exp'] = past
    expired_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(expired_token)

    token_data.pop('exp')
    no_exp_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_exp_token)

  def test_no_sub(self):
    """A token without a 'sub' claim is rejected."""
    token_data = self._generate_token_data()
    token_data.pop('sub')
    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

  def test_iss(self):
    """Tokens with an unexpected or a missing 'iss' claim are rejected."""
    token_data = self._generate_token_data(iss='badissuer')
    token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(token)

    token_data.pop('iss')
    no_iss_token = self._generate_token(token_data)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(no_iss_token)

  def test_missing_header(self):
    """A token encoded without the 'kid' header is rejected."""
    token_data = self._generate_token_data()
    missing_header_token = self._generate_token(token_data, skip_header=True)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(missing_header_token)

  def test_invalid_key(self):
    """A token whose 'kid' names an unknown key is rejected."""
    token_data = self._generate_token_data()
    invalid_key_token = self._generate_token(token_data, key_id='someunknownkey')
    with self.assertRaises(InvalidJWTException):
      self._parse_token(invalid_key_token)

  def test_expired_key(self):
    """A token whose 'kid' is 'kid7' is rejected.

    NOTE(review): 'kid7' is presumably an expired service key created by the
    test database fixtures -- confirm against initdb.
    """
    token_data = self._generate_token_data()
    expired_key_token = self._generate_token(token_data, key_id='kid7')
    with self.assertRaises(InvalidJWTException):
      self._parse_token(expired_key_token)

  def test_mixing_keys(self):
    """Exercise approval, key/kid mismatches, deletion and caching of a new service key."""
    token_data = self._generate_token_data()

    # Create a new key for testing.
    p, key = model.service_keys.generate_service_key(instance_keys.service_name, None, kid='newkey',
                                                     name='newkey', metadata={})

    private_key = p.exportKey('PEM')

    # Test first with the new valid, but unapproved key.
    unapproved_key_token = self._generate_token(token_data, key_id='newkey',
                                                private_key=private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(unapproved_key_token)

    # Approve the key and try again.
    admin_user = model.user.get_user('devtable')
    model.service_keys.approve_service_key(key.kid, admin_user, ServiceKeyApprovalType.SUPERUSER)

    valid_token = self._generate_token(token_data, key_id='newkey', private_key=private_key)

    identity = self._parse_token(valid_token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    # Try using a different private key with the existing key ID.
    bad_private_token = self._generate_token(token_data, key_id='newkey',
                                             private_key=instance_keys.local_private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(bad_private_token)

    # Try using a different key ID with the existing private key.
    kid_mismatch_token = self._generate_token(token_data, key_id=instance_keys.local_key_id,
                                              private_key=private_key)
    with self.assertRaises(InvalidJWTException):
      self._parse_token(kid_mismatch_token)

    # Delete the new key.
    key.delete_instance(recursive=True)

    # Ensure it still works (via the cache.)
    deleted_key_token = self._generate_token(token_data, key_id='newkey', private_key=private_key)
    identity = self._parse_token(deleted_key_token)
    self.assertEqual(identity.id, TEST_USER.username)
    self.assertEqual(0, len(identity.provides))

    # Break the cache.
    instance_keys.clear_cache()

    # Ensure the key no longer works.
    with self.assertRaises(InvalidJWTException):
      self._parse_token(deleted_key_token)

  def test_bad_token(self):
    """A header value that is not a bearer token at all is rejected."""
    with self.assertRaises(InvalidJWTException):
      self._parse_token("some random token here")

  def test_bad_bearer_token(self):
    """A header using 'Bearer:' (with a colon) is rejected."""
    with self.assertRaises(InvalidJWTException):
      self._parse_token("Bearer: sometokenhere")

  def test_bad_bearer_newline_token(self):
    """A malformed bearer header preceded by a newline is rejected."""
    with self.assertRaises(InvalidJWTException):
      self._parse_token("\nBearer: dGVzdA")

  def test_ensure_no_none(self):
    """A token encoded with the 'none' algorithm and no key is rejected."""
    token_data = self._generate_token_data()
    none_token = self._generate_token(token_data, alg='none', private_key=None)

    with self.assertRaises(InvalidJWTException):
      self._parse_token(none_token)

# Run the test suite when this module is executed directly.
if __name__ == '__main__':
  unittest.main()