This repository has been archived on 2020-03-24. You can view files and clone it, but cannot push or open issues or pull requests.
quay/endpoints/test/shared.py

81 lines
2.1 KiB
Python
Raw Normal View History

2019-11-12 16:09:47 +00:00
import datetime
import json
import base64
from contextlib import contextmanager
from data import model
from flask import g
from flask_principal import Identity
# Key under which the CSRF token is stored in the test client's session; the
# helpers in this module also use it as the name of the CSRF request parameter.
CSRF_TOKEN_KEY = '_csrf_token'
@contextmanager
def client_with_identity(auth_username, client):
    """ Context manager which yields `client` with its session set up for a user.

    If `auth_username` is truthy, the user is looked up via
    `model.user.get_user` and the session's `user_id` and `login_time` are
    filled in from it; otherwise the session's `user_id` is set to
    'anonymous'.

    On exit the `user_id`, `login_time` and CSRF token entries of the session
    are reset to None. The reset runs even when the body of the `with` block
    raises (e.g. a failing assertion), so identity state does not remain in
    the client's session after a failed test.
    """
    with client.session_transaction() as sess:
        if auth_username:
            loaded = model.user.get_user(auth_username)
            sess['user_id'] = loaded.uuid
            sess['login_time'] = datetime.datetime.now()
        else:
            sess['user_id'] = 'anonymous'

    try:
        yield client
    finally:
        # Always clear the identity, even if the caller's block raised.
        with client.session_transaction() as sess:
            sess['user_id'] = None
            sess['login_time'] = None
            sess[CSRF_TOKEN_KEY] = None
@contextmanager
def toggle_feature(name, enabled):
    """ Context manager which temporarily toggles a feature.

    Sets attribute `name` on the `features` module to `enabled` for the
    duration of the `with` block and restores the previous value afterwards.
    The previous value is restored even when the block raises, so a failing
    test cannot leave the feature toggled.

    Raises AttributeError if `features` has no attribute called `name`.
    """
    # Imported lazily, as in the original helper.
    import features

    previous_value = getattr(features, name)
    setattr(features, name, enabled)
    try:
        yield
    finally:
        setattr(features, name, previous_value)
def add_csrf_param(client, params):
    """ Returns a params dict with the CSRF parameter added.

    The same fixed token is stored in the client's session under
    CSRF_TOKEN_KEY and in the returned params. `params` may be None or empty.
    The caller's dict is not modified: the token is added to a shallow copy,
    so a params dict shared between calls does not accumulate the token.
    """
    params = params.copy() if params else {}

    with client.session_transaction() as sess:
        params[CSRF_TOKEN_KEY] = 'sometoken'
        sess[CSRF_TOKEN_KEY] = 'sometoken'

    return params
def gen_basic_auth(username, password):
    """ Generates a basic auth header.

    Returns the string 'Basic ' followed by the base64 encoding of
    "username:password".
    """
    credentials = "%s:%s" % (username, password)

    # b64encode needs bytes: on Python 3 (and for unicode input on Python 2) the
    # formatted credentials are text and must be encoded first.
    if not isinstance(credentials, bytes):
        credentials = credentials.encode("utf-8")

    encoded = base64.b64encode(credentials)

    # On Python 3 b64encode returns bytes, which cannot be concatenated to a str.
    if not isinstance(encoded, str):
        encoded = encoded.decode("ascii")

    return 'Basic ' + encoded
def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200,
                 headers=None, raw_body=None):
    """ Conducts a call to a Flask endpoint.

    The URL is built with `url_for(resource, **params)` after the CSRF
    parameter has been added to `params`. The request is sent through
    `client.open` with a JSON content type. `body`, when given, is serialized
    with `json.dumps`; `raw_body`, when given, is sent as-is and takes
    precedence over `body`.

    Returns the response object. Raises AssertionError (with the method, URL,
    status codes and response data in the message) if the response status does
    not equal `expected_code`.
    """
    params = add_csrf_param(client, params)

    final_url = url_for(resource, **params)

    # Work on a copy so the caller's headers dict is not modified.
    headers = headers.copy() if headers else {}
    headers.update({"Content-Type": "application/json"})

    if body is not None:
        body = json.dumps(body)

    if raw_body is not None:
        body = raw_body

    # Required for anonymous calls to not exception.
    g.identity = Identity(None, 'none')

    rv = client.open(final_url, method=method, data=body, headers=headers)
    msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
                                                rv.data)
    assert rv.status_code == expected_code, msg
    return rv