Move conduct_call into a common test lib for all endpoints

This commit is contained in:
Joseph Schorr 2017-06-28 11:38:36 +03:00
parent 91d2cb1ec1
commit 2f018046ec
12 changed files with 83 additions and 66 deletions

View file

@ -1,58 +1,10 @@
import datetime
import json
from contextlib import contextmanager
from data import model
from endpoints.test.shared import conduct_call
from endpoints.api import api
CSRF_TOKEN_KEY = '_csrf_token'
CSRF_TOKEN = '123csrfforme'
@contextmanager
def client_with_identity(auth_username, client):
with client.session_transaction() as sess:
if auth_username and auth_username is not None:
loaded = model.user.get_user(auth_username)
sess['user_id'] = loaded.uuid
sess['login_time'] = datetime.datetime.now()
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
else:
sess['user_id'] = 'anonymous'
yield client
with client.session_transaction() as sess:
sess['user_id'] = None
sess['login_time'] = None
sess[CSRF_TOKEN_KEY] = None
def add_csrf_param(params):
""" Returns a params dict with the CSRF parameter added. """
params = params or {}
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
return params
def conduct_api_call(client, resource, method, params, body=None, expected_code=200):
""" Conducts an API call to the given resource via the given client, and ensures its returned
status matches the code given.
Returns the response.
"""
params = add_csrf_param(params)
final_url = api.url_for(resource, **params)
headers = {}
headers.update({"Content-Type": "application/json"})
if body is not None:
body = json.dumps(body)
rv = client.open(final_url, method=method, data=body, headers=headers)
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
rv.data)
assert rv.status_code == expected_code, msg
return rv
return conduct_call(client, resource, api.url_for, method, params, body, expected_code)

View file

@ -16,7 +16,8 @@ from endpoints.api.trigger import (BuildTriggerList, BuildTrigger, BuildTriggerS
BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger,
TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources,
BuildTriggerSourceNamespaces)
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.test.shared import client_with_identity
from test.fixtures import *
BUILD_ARGS = {'build_uuid': '1234'}

View file

@ -2,8 +2,9 @@ import pytest
from data import model
from endpoints.api import api
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.api.organization import Organization
from endpoints.test.shared import client_with_identity
from test.fixtures import *
@pytest.mark.parametrize('expiration, expected_code', [

View file

@ -2,8 +2,9 @@ import pytest
from mock import patch, ANY, MagicMock
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.api.repository import RepositoryTrust, Repository
from endpoints.test.shared import client_with_identity
from features import FeatureNameValue
from test.fixtures import *

View file

@ -4,7 +4,8 @@ from playhouse.test_utils import assert_query_count
from data.model import _basequery
from endpoints.api.search import ConductRepositorySearch, ConductSearch
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.test.shared import client_with_identity
from test.fixtures import *
@pytest.mark.parametrize('query, expected_query_count', [

View file

@ -4,12 +4,13 @@ from flask_principal import AnonymousIdentity
from endpoints.api import api
from endpoints.api.repositorynotification import RepositoryNotification
from endpoints.api.team import OrganizationTeamSyncing
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.api.repository import RepositoryTrust
from endpoints.api.signing import RepositorySignatures
from endpoints.api.search import ConductRepositorySearch
from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource
from endpoints.api.superuser import SuperUserRepositoryBuildStatus
from endpoints.test.shared import client_with_identity
from test.fixtures import *

View file

@ -3,8 +3,9 @@ import pytest
from collections import Counter
from mock import patch
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.api.signing import RepositorySignatures
from endpoints.test.shared import client_with_identity
from test.fixtures import *

View file

@ -2,8 +2,10 @@ import pytest
from mock import patch, Mock
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.api.tag import RepositoryTag, RestoreTag
from endpoints.test.shared import client_with_identity
from features import FeatureNameValue
from test.fixtures import *

View file

@ -4,9 +4,11 @@ from mock import patch
from data import model
from endpoints.api import api
from endpoints.api.test.shared import client_with_identity, conduct_api_call
from endpoints.api.test.shared import conduct_api_call
from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList
from endpoints.api.organization import Organization
from endpoints.test.shared import client_with_identity
from test.test_ldap import mock_ldap
from test.fixtures import *

View file

@ -5,7 +5,7 @@ from flask import url_for
from data import model
from endpoints.appr.registry import appr_bp, blobs
from endpoints.api.test.shared import client_with_identity
from endpoints.test.shared import client_with_identity
from test.fixtures import *
BLOB_ARGS = {'digest': 'abcd1235'}

View file

55
endpoints/test/shared.py Normal file
View file

@ -0,0 +1,55 @@
import datetime
import json
from contextlib import contextmanager
from data import model
CSRF_TOKEN_KEY = '_csrf_token'
CSRF_TOKEN = '123csrfforme'
@contextmanager
def client_with_identity(auth_username, client):
with client.session_transaction() as sess:
if auth_username and auth_username is not None:
loaded = model.user.get_user(auth_username)
sess['user_id'] = loaded.uuid
sess['login_time'] = datetime.datetime.now()
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
else:
sess['user_id'] = 'anonymous'
yield client
with client.session_transaction() as sess:
sess['user_id'] = None
sess['login_time'] = None
sess[CSRF_TOKEN_KEY] = None
def add_csrf_param(params):
""" Returns a params dict with the CSRF parameter added. """
params = params or {}
if not CSRF_TOKEN_KEY in params:
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
return params
def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200, headers=None):
""" Conducts a call to a Flask endpoint. """
params = add_csrf_param(params)
final_url = url_for(resource, **params)
headers = headers or {}
headers.update({"Content-Type": "application/json"})
if body is not None:
body = json.dumps(body)
rv = client.open(final_url, method=method, data=body, headers=headers)
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
rv.data)
assert rv.status_code == expected_code, msg
return rv