import logging
import os
import base64
import hmac

from functools import wraps
from flask import session, request

from app import app
from auth.auth_context import get_validated_oauth_token
from util.http import abort


# Module-level logger; CSRF verification failures are reported through it.
logger = logging.getLogger(__name__)

# Token name presumably used by the OAuth flow elsewhere (not referenced in this module).
OAUTH_CSRF_TOKEN_NAME = "_oauth_csrf_token"
# Default name for both the session key and the request field carrying the CSRF token.
_QUAY_CSRF_TOKEN_NAME = "_csrf_token"

def generate_csrf_token(session_token_name=_QUAY_CSRF_TOKEN_NAME, force=False):
  """ If not present in the session (or when `force` is set), generates a new CSRF token
      with the given name and places it into the session. Returns the session's token.

      A new token is 48 bytes from `os.urandom`, base64-encoded and decoded to text, so the
      session and templates hold a plain string rather than `bytes` (which on Python 3 would
      be stringified as "b'...'").

      :param session_token_name: session key under which the token is stored.
      :param force: when True, replace any existing token with a fresh one (e.g. to rotate
                    the token). Defaults to False, which preserves the previous behavior.
      :returns: the token currently stored in the session under `session_token_name`.
  """
  if force or session_token_name not in session:
    # b64encode returns bytes on Python 3; decode so the stored token is text. The base64
    # alphabet is pure ASCII, so this decode cannot fail.
    session[session_token_name] = base64.b64encode(os.urandom(48)).decode('ascii')

  return session[session_token_name]


def verify_csrf(session_token_name=_QUAY_CSRF_TOKEN_NAME,
                request_token_name=_QUAY_CSRF_TOKEN_NAME):
  """ Verifies that the CSRF token with the given name is found in the session and
      that the matching token is found in the request args or values.

      If either token is missing or they differ, logs the failure and calls
      `abort(403, ...)`. Token values are deliberately never logged: the session token is a
      secret, and writing it to the log would expose it to anyone with log access.

      :param session_token_name: session key holding the expected token.
      :param request_token_name: request arg/form field carrying the submitted token.
  """
  token = str(session.get(session_token_name, ''))
  found_token = str(request.values.get(request_token_name, ''))

  # hmac.compare_digest raises TypeError for non-ASCII `str` input on Python 3, which would
  # turn a crafted token into a 500; comparing UTF-8 bytes keeps it a constant-time 403 path.
  if (not token or not found_token or
      not hmac.compare_digest(token.encode('utf-8'), found_token.encode('utf-8'))):
    # Only report presence; when both are "present" the failure was a mismatch.
    msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s'
    logger.error(msg, session_token_name, 'present' if token else 'missing',
                 request_token_name, 'present' if found_token else 'missing')
    abort(403, message='CSRF token was invalid or missing.')


def csrf_protect(session_token_name=_QUAY_CSRF_TOKEN_NAME,
                 request_token_name=_QUAY_CSRF_TOKEN_NAME,
                 all_methods=False):
  """ Returns a decorator that runs `verify_csrf` before the wrapped view.

      The check is skipped when `get_validated_oauth_token()` returns a token. Otherwise it
      runs for every request method when `all_methods` is True, and for every method other
      than GET and HEAD when it is False.

      :param session_token_name: session key holding the expected token.
      :param request_token_name: request arg/form field carrying the submitted token.
      :param all_methods: also verify GET and HEAD requests.
  """
  exempt_methods = ("GET", "HEAD")

  def decorator(view_func):
    @wraps(view_func)
    def protected_view(*args, **kwargs):
      # The OAuth lookup always happens first; request.method is only consulted afterwards.
      has_oauth_token = get_validated_oauth_token() is not None
      if not has_oauth_token and (all_methods or request.method not in exempt_methods):
        verify_csrf(session_token_name, request_token_name)
      return view_func(*args, **kwargs)

    return protected_view

  return decorator


# Make `csrf_token()` callable from Jinja templates so rendered pages can embed the token.
app.jinja_env.globals.update(csrf_token=generate_csrf_token)