import re

from calendar import timegm
from datetime import datetime, timedelta
from jwt import PyJWT
from jwt.exceptions import (
  InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
  ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError,
  InvalidAlgorithmError
)

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwkest.jwk import keyrep, RSAKey, ECKey


# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')

# ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD
# `none` here!
ALGORITHM_WHITELIST = [
  'rs256'
]

class _StrictJWT(PyJWT):
  """ _StrictJWT defines a JWT decoder with extra checks. """

  @staticmethod
  def _get_default_options():
    # Weird syntax to call super on a staticmethod
    defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
    defaults.update({
      'require_exp': True,
      'require_iat': True,
      'require_nbf': True,
      'exp_max_s': None,
    })
    return defaults

  def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
    if options.get('exp_max_s') is not None:
      if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
        raise ValueError('exp_max_s option implies verify_expiration')

      options['verify_exp'] = True

    # Do all of the other checks
    super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)

    now = timegm(datetime.utcnow().utctimetuple())
    self._reject_future_iat(payload, now, leeway)

    if 'exp' in payload and options.get('exp_max_s') is not None:
      # Validate that the expiration was not more than exp_max_s seconds after the issue time
      # or in the absence of an issue time, more than exp_max_s in the future from now

      # This will work because the parent method already checked the type of exp
      expiration = datetime.utcfromtimestamp(int(payload['exp']))
      max_signed_s = options.get('exp_max_s')

      start_time = datetime.utcnow()
      if 'iat' in payload:
        start_time = datetime.utcfromtimestamp(int(payload['iat']))

      if expiration > start_time + timedelta(seconds=max_signed_s):
        raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s,
                                start_time)

  def _reject_future_iat(self, payload, now, leeway):
    try:
      iat = int(payload['iat'])
    except ValueError:
      raise DecodeError('Issued At claim (iat) must be an integer.')

    if iat > (now + leeway):
      raise InvalidIssuedAtError('Issued At claim (iat) cannot be in'
                                 ' the future.')


def decode(jwt, key='', verify=True, algorithms=None, options=None,
           **kwargs):
  """ Decodes a JWT. """
  if not algorithms:
    raise InvalidAlgorithmError('algorithms must be specified')

  normalized = set([a.lower() for a in algorithms])
  if 'none' in normalized:
    raise InvalidAlgorithmError('`none` algorithm is not allowed')

  if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized):
    raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' %
                                (algorithms, ALGORITHM_WHITELIST))

  return _StrictJWT().decode(jwt, key, verify, algorithms, options, **kwargs)


def exp_max_s_option(max_exp_s):
  """ Returns an options dictionary that sets the maximum expiration seconds for a JWT. """
  return {
    'exp_max_s': max_exp_s,
  }


def jwk_dict_to_public_key(jwk):
  """ Converts the specified JWK into a public key. """
  jwkest_key = keyrep(jwk)
  if isinstance(jwkest_key, RSAKey):
    pycrypto_key = jwkest_key.key
    return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())
  elif isinstance(jwkest_key, ECKey):
    x, y = jwkest_key.get_key()
    return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())

  raise Exception('Unsupported kind of JWK: %s', str(type(jwkest_key)))