import time
import jwt
import logging

from util.security import jwtutil

logger = logging.getLogger(__name__)

# Subject ('sub' claim) returned by build_context_and_subject when neither a user, a token nor
# an OAuth token is supplied.
ANONYMOUS_SUB = '(anonymous)'
# Signature algorithm used both when signing issued tokens and when verifying incoming ones.
ALGORITHM = 'RS256'

# The number of allowed seconds of clock skew for a JWT. It is passed as the `leeway` argument
# when incoming tokens are validated (see decode_bearer_token), so the time-based claims
# (iat, nbf and exp) are checked with this much tolerance.
JWT_CLOCK_SKEW_SECONDS = 30


class InvalidBearerTokenException(Exception):
  """ Raised when a bearer header or token cannot be decoded or fails validation. The exception's
      argument carries a description (or the underlying error) explaining the failure. """


def decode_bearer_header(bearer_header, instance_keys, config):
  """ decode_bearer_header decodes the given bearer header that contains an encoded JWT with both
      a Key ID as well as the signed JWT and returns the decoded and validated JWT. On any error,
      raises an InvalidBearerTokenException with the reason for failure.

      Args:
        bearer_header: the raw header value; the token is extracted with jwtutil.TOKEN_REGEX.
        instance_keys: forwarded unchanged to decode_bearer_token.
        config: forwarded unchanged to decode_bearer_token.

      Returns:
        The validated JWT payload returned by decode_bearer_token.

      Raises:
        InvalidBearerTokenException: if the header is missing or does not match the expected
          format, or if decode_bearer_token rejects the token.
  """
  # A missing header is reported like a malformed one, rather than failing inside the regex
  # match with an unrelated exception type.
  if bearer_header is None:
    raise InvalidBearerTokenException('Invalid bearer token format')

  # Extract the jwt token from the header
  match = jwtutil.TOKEN_REGEX.match(bearer_header)
  if match is None:
    raise InvalidBearerTokenException('Invalid bearer token format')

  # The encoded JWT is a bearer credential, so it is deliberately not written to the logs.
  encoded_jwt = match.group(1)
  return decode_bearer_token(encoded_jwt, instance_keys, config)


def decode_bearer_token(bearer_token, instance_keys, config):
  """ decode_bearer_token decodes the given bearer token that contains both a Key ID as well as the
      encoded JWT and returns the decoded and validated JWT. On any error, raises an
      InvalidBearerTokenException with the reason for failure.

      Args:
        bearer_token: the compact-encoded JWT (without any 'Bearer ' prefix).
        instance_keys: provides `service_name` (the expected issuer) and
          `get_service_key_public_key(kid)` to look up the verification key.
        config: mapping that must contain 'SERVER_HOSTNAME' (the expected audience) and may contain
          'REGISTRY_JWT_AUTH_MAX_FRESH_S' (default 3660), which is passed to
          jwtutil.exp_max_s_option — presumably a cap on the token's allowed lifetime; confirm
          against jwtutil.

      Returns:
        The decoded JWT payload, guaranteed to contain a 'sub' claim.

      Raises:
        InvalidBearerTokenException: if the token is malformed, has no 'kid' header, names an
          unknown service key, fails signature/claim validation, or lacks a 'sub' claim.
        KeyError: if config has no 'SERVER_HOSTNAME' entry (plain dict lookup).
  """
  # Decode the key ID. A malformed token makes PyJWT raise a subclass of InvalidTokenError here;
  # translate it so callers only ever have to handle InvalidBearerTokenException.
  try:
    headers = jwt.get_unverified_header(bearer_token)
  except jwt.InvalidTokenError as ite:
    logger.exception('Invalid token reason: %s', ite)
    raise InvalidBearerTokenException(ite)

  kid = headers.get('kid')
  if kid is None:
    # The token itself is a bearer credential, so it is not included in the log message.
    logger.error('Missing kid header on encoded JWT')
    raise InvalidBearerTokenException('Missing kid header')

  # Find the matching public key.
  public_key = instance_keys.get_service_key_public_key(kid)
  if public_key is None:
    logger.error('Could not find requested service key %s', kid)
    raise InvalidBearerTokenException('Unknown service key')

  # Gather the validation parameters outside the try so only decoding errors are translated.
  expected_issuer = instance_keys.service_name
  audience = config['SERVER_HOSTNAME']
  max_signed_s = config.get('REGISTRY_JWT_AUTH_MAX_FRESH_S', 3660)
  max_exp = jwtutil.exp_max_s_option(max_signed_s)

  # Load the JWT returned.
  try:
    payload = jwtutil.decode(bearer_token, public_key, algorithms=[ALGORITHM], audience=audience,
                             issuer=expected_issuer, options=max_exp, leeway=JWT_CLOCK_SKEW_SECONDS)
  except jwtutil.InvalidTokenError as ite:
    logger.exception('Invalid token reason: %s', ite)
    raise InvalidBearerTokenException(ite)

  if 'sub' not in payload:
    raise InvalidBearerTokenException('Missing sub field in JWT')

  return payload


def generate_bearer_token(audience, subject, context, access, lifetime_s, instance_keys):
  """ Builds a signed registry bearer token (without the 'Bearer ' portion) for the given
      audience, subject, context and access, valid for `lifetime_s` seconds. The token is issued
      under `instance_keys.service_name` and signed with the instance's local private key; that
      key's ID is carried in the token's `kid` header.
  """
  issuer = instance_keys.service_name
  signing_key_id = instance_keys.local_key_id
  signing_key = instance_keys.local_private_key
  return _generate_jwt_object(audience, subject, context, access, lifetime_s,
                              issuer, signing_key_id, signing_key)


def _generate_jwt_object(audience, subject, context, access, lifetime_s, issuer, key_id,
                         private_key):
  """ Generates a compact encoded JWT with the values specified.

      The token becomes valid at the moment of issuance (nbf == iat) and expires `lifetime_s`
      seconds later. It is signed with `private_key` using ALGORITHM, and `key_id` is placed in
      the 'kid' header so a verifier can look up the matching public key.

      Returns:
        Whatever jwt.encode returns for the installed PyJWT version (bytes in 1.x, str in 2.x —
        confirm against the pinned version).
  """
  # Sample the clock once so nbf, iat and exp all derive from the same instant; separate
  # time.time() calls can straddle a second boundary and yield inconsistent claims.
  now = time.time()
  token_data = {
    'iss': issuer,
    'aud': audience,
    'nbf': int(now),
    'iat': int(now),
    'exp': int(now + lifetime_s),
    'sub': subject,
    'access': access,
    'context': context,
  }

  token_headers = {
    'kid': key_id,
  }

  return jwt.encode(token_data, private_key, ALGORITHM, headers=token_headers)


def build_context_and_subject(user, token, oauthtoken):
  """ Returns a (context, subject) pair for a JWT signed token.

      The context is chosen by precedence: an OAuth token first, then a user, then a token, and
      finally an anonymous context when none of them is set. The subject is the username for the
      OAuth and user contexts, None for a token context and ANONYMOUS_SUB for anonymous access.
  """
  # The OAuth branch reads user.username, so a user is expected whenever oauthtoken is set.
  if oauthtoken:
    username = user.username
    return ({'kind': 'oauth', 'user': username, 'oauth': oauthtoken.uuid}, username)

  if user:
    username = user.username
    return ({'kind': 'user', 'user': username}, username)

  if token:
    return ({'kind': 'token', 'token': token.code}, None)

  return ({'kind': 'anonymous'}, ANONYMOUS_SUB)