Move conduct_call into a common test lib for all endpoints

This commit is contained in:
Joseph Schorr 2017-06-28 11:38:36 +03:00
parent 91d2cb1ec1
commit 2f018046ec
12 changed files with 83 additions and 66 deletions

View file

@ -1,58 +1,10 @@
import datetime from endpoints.test.shared import conduct_call
import json
from contextlib import contextmanager
from data import model
from endpoints.api import api from endpoints.api import api
CSRF_TOKEN_KEY = '_csrf_token'
CSRF_TOKEN = '123csrfforme'
@contextmanager
def client_with_identity(auth_username, client):
with client.session_transaction() as sess:
if auth_username and auth_username is not None:
loaded = model.user.get_user(auth_username)
sess['user_id'] = loaded.uuid
sess['login_time'] = datetime.datetime.now()
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
else:
sess['user_id'] = 'anonymous'
yield client
with client.session_transaction() as sess:
sess['user_id'] = None
sess['login_time'] = None
sess[CSRF_TOKEN_KEY] = None
def add_csrf_param(params):
""" Returns a params dict with the CSRF parameter added. """
params = params or {}
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
return params
def conduct_api_call(client, resource, method, params, body=None, expected_code=200): def conduct_api_call(client, resource, method, params, body=None, expected_code=200):
""" Conducts an API call to the given resource via the given client, and ensures its returned """ Conducts an API call to the given resource via the given client, and ensures its returned
status matches the code given. status matches the code given.
Returns the response. Returns the response.
""" """
params = add_csrf_param(params) return conduct_call(client, resource, api.url_for, method, params, body, expected_code)
final_url = api.url_for(resource, **params)
headers = {}
headers.update({"Content-Type": "application/json"})
if body is not None:
body = json.dumps(body)
rv = client.open(final_url, method=method, data=body, headers=headers)
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
rv.data)
assert rv.status_code == expected_code, msg
return rv

View file

@ -16,7 +16,8 @@ from endpoints.api.trigger import (BuildTriggerList, BuildTrigger, BuildTriggerS
BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger, BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger,
TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources, TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources,
BuildTriggerSourceNamespaces) BuildTriggerSourceNamespaces)
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.test.shared import client_with_identity
from test.fixtures import * from test.fixtures import *
BUILD_ARGS = {'build_uuid': '1234'} BUILD_ARGS = {'build_uuid': '1234'}

View file

@ -2,8 +2,9 @@ import pytest
from data import model from data import model
from endpoints.api import api from endpoints.api import api
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.api.organization import Organization from endpoints.api.organization import Organization
from endpoints.test.shared import client_with_identity
from test.fixtures import * from test.fixtures import *
@pytest.mark.parametrize('expiration, expected_code', [ @pytest.mark.parametrize('expiration, expected_code', [

View file

@ -2,8 +2,9 @@ import pytest
from mock import patch, ANY, MagicMock from mock import patch, ANY, MagicMock
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.api.repository import RepositoryTrust, Repository from endpoints.api.repository import RepositoryTrust, Repository
from endpoints.test.shared import client_with_identity
from features import FeatureNameValue from features import FeatureNameValue
from test.fixtures import * from test.fixtures import *

View file

@ -4,7 +4,8 @@ from playhouse.test_utils import assert_query_count
from data.model import _basequery from data.model import _basequery
from endpoints.api.search import ConductRepositorySearch, ConductSearch from endpoints.api.search import ConductRepositorySearch, ConductSearch
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.test.shared import client_with_identity
from test.fixtures import * from test.fixtures import *
@pytest.mark.parametrize('query, expected_query_count', [ @pytest.mark.parametrize('query, expected_query_count', [

View file

@ -4,12 +4,13 @@ from flask_principal import AnonymousIdentity
from endpoints.api import api from endpoints.api import api
from endpoints.api.repositorynotification import RepositoryNotification from endpoints.api.repositorynotification import RepositoryNotification
from endpoints.api.team import OrganizationTeamSyncing from endpoints.api.team import OrganizationTeamSyncing
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.api.repository import RepositoryTrust from endpoints.api.repository import RepositoryTrust
from endpoints.api.signing import RepositorySignatures from endpoints.api.signing import RepositorySignatures
from endpoints.api.search import ConductRepositorySearch from endpoints.api.search import ConductRepositorySearch
from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource
from endpoints.api.superuser import SuperUserRepositoryBuildStatus from endpoints.api.superuser import SuperUserRepositoryBuildStatus
from endpoints.test.shared import client_with_identity
from test.fixtures import * from test.fixtures import *

View file

@ -3,8 +3,9 @@ import pytest
from collections import Counter from collections import Counter
from mock import patch from mock import patch
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.api.signing import RepositorySignatures from endpoints.api.signing import RepositorySignatures
from endpoints.test.shared import client_with_identity
from test.fixtures import * from test.fixtures import *

View file

@ -2,8 +2,10 @@ import pytest
from mock import patch, Mock from mock import patch, Mock
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.api.tag import RepositoryTag, RestoreTag from endpoints.api.tag import RepositoryTag, RestoreTag
from endpoints.test.shared import client_with_identity
from features import FeatureNameValue from features import FeatureNameValue
from test.fixtures import * from test.fixtures import *

View file

@ -4,9 +4,11 @@ from mock import patch
from data import model from data import model
from endpoints.api import api from endpoints.api import api
from endpoints.api.test.shared import client_with_identity, conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList
from endpoints.api.organization import Organization from endpoints.api.organization import Organization
from endpoints.test.shared import client_with_identity
from test.test_ldap import mock_ldap from test.test_ldap import mock_ldap
from test.fixtures import * from test.fixtures import *

View file

@ -5,7 +5,7 @@ from flask import url_for
from data import model from data import model
from endpoints.appr.registry import appr_bp, blobs from endpoints.appr.registry import appr_bp, blobs
from endpoints.api.test.shared import client_with_identity from endpoints.test.shared import client_with_identity
from test.fixtures import * from test.fixtures import *
BLOB_ARGS = {'digest': 'abcd1235'} BLOB_ARGS = {'digest': 'abcd1235'}

View file

55
endpoints/test/shared.py Normal file
View file

@ -0,0 +1,55 @@
import datetime
import json
from contextlib import contextmanager
from data import model
CSRF_TOKEN_KEY = '_csrf_token'
CSRF_TOKEN = '123csrfforme'
@contextmanager
def client_with_identity(auth_username, client):
with client.session_transaction() as sess:
if auth_username and auth_username is not None:
loaded = model.user.get_user(auth_username)
sess['user_id'] = loaded.uuid
sess['login_time'] = datetime.datetime.now()
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
else:
sess['user_id'] = 'anonymous'
yield client
with client.session_transaction() as sess:
sess['user_id'] = None
sess['login_time'] = None
sess[CSRF_TOKEN_KEY] = None
def add_csrf_param(params):
""" Returns a params dict with the CSRF parameter added. """
params = params or {}
if not CSRF_TOKEN_KEY in params:
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
return params
def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200, headers=None):
""" Conducts a call to a Flask endpoint. """
params = add_csrf_param(params)
final_url = url_for(resource, **params)
headers = headers or {}
headers.update({"Content-Type": "application/json"})
if body is not None:
body = json.dumps(body)
rv = client.open(final_url, method=method, data=body, headers=headers)
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
rv.data)
assert rv.status_code == expected_code, msg
return rv