import base64
import hmac
import logging
import os
from functools import wraps

from flask import session, request

from app import app
from auth.auth_context import get_validated_oauth_token
from util.http import abort


logger = logging.getLogger(__name__)


def generate_csrf_token():
  """Return the session's CSRF token, creating and storing one on first use.

  The token is 48 bytes from ``os.urandom``, base64-encoded and stored in
  ``session['_csrf_token']``. It is decoded to text because ``b64encode``
  returns ``bytes`` on Python 3. A bytes token would render in templates as
  ``b'...'``, and it would never equal the string value submitted back in
  ``request.values``.
  """
  if '_csrf_token' not in session:
    session['_csrf_token'] = base64.b64encode(os.urandom(48)).decode('ascii')

  return session['_csrf_token']

def _csrf_token_bytes(value):
  """Return *value* as bytes, UTF-8 encoding it when it is text."""
  if isinstance(value, bytes):
    return value
  return value.encode('utf-8')

def verify_csrf():
  """Abort with HTTP 403 unless the request's CSRF token matches the session's.

  The expected token is ``session['_csrf_token']``. The submitted token is the
  ``_csrf_token`` entry of ``request.values`` (query string or form body). The
  check fails when the token is missing on either side or when the two differ.
  """
  token = session.get('_csrf_token', None)
  found_token = request.values.get('_csrf_token', None)

  # compare_digest takes time independent of where the values differ, so the
  # token cannot be recovered byte-by-byte from response timings. Both sides are
  # normalized to bytes first because compare_digest raises TypeError on mixed
  # str/bytes or non-ASCII str. The submitted value is attacker-controlled, and
  # older sessions may still hold a bytes token.
  valid = bool(token) and bool(found_token) and hmac.compare_digest(
      _csrf_token_bytes(token), _csrf_token_bytes(found_token))

  if not valid:
    # Never log the token values themselves: the session token is a secret, and
    # anyone with log access could use it to forge requests for that session.
    logger.error('CSRF Failure. Session token present: %s, request token present: %s',
                 bool(token), bool(found_token))
    abort(403, message='CSRF token was invalid or missing.')

def csrf_protect(func):
  """Wrap a view so that non-GET/HEAD requests without OAuth must pass verify_csrf.

  On every call the wrapper first looks up the validated OAuth token. If there
  is none and the HTTP method is neither ``GET`` nor ``HEAD``, it runs
  ``verify_csrf``, which aborts the request on failure. Otherwise, or once the
  check passes, it calls the wrapped view with the original arguments.
  """
  @wraps(func)
  def wrapper(*args, **kwargs):
    has_oauth = get_validated_oauth_token() is not None
    needs_check = not has_oauth and request.method not in ("GET", "HEAD")
    if needs_check:
      verify_csrf()
    return func(*args, **kwargs)

  return wrapper


# Register the token generator as the Jinja global `csrf_token`, so templates can call it.
app.jinja_env.globals.update(csrf_token=generate_csrf_token)