Move conduct_call into a common test lib for all endpoints
This commit is contained in:
parent
91d2cb1ec1
commit
2f018046ec
12 changed files with 83 additions and 66 deletions
|
@ -1,58 +1,10 @@
|
||||||
import datetime
|
from endpoints.test.shared import conduct_call
|
||||||
import json
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from data import model
|
|
||||||
from endpoints.api import api
|
from endpoints.api import api
|
||||||
|
|
||||||
CSRF_TOKEN_KEY = '_csrf_token'
|
|
||||||
CSRF_TOKEN = '123csrfforme'
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def client_with_identity(auth_username, client):
|
|
||||||
with client.session_transaction() as sess:
|
|
||||||
if auth_username and auth_username is not None:
|
|
||||||
loaded = model.user.get_user(auth_username)
|
|
||||||
sess['user_id'] = loaded.uuid
|
|
||||||
sess['login_time'] = datetime.datetime.now()
|
|
||||||
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
|
||||||
else:
|
|
||||||
sess['user_id'] = 'anonymous'
|
|
||||||
|
|
||||||
yield client
|
|
||||||
|
|
||||||
with client.session_transaction() as sess:
|
|
||||||
sess['user_id'] = None
|
|
||||||
sess['login_time'] = None
|
|
||||||
sess[CSRF_TOKEN_KEY] = None
|
|
||||||
|
|
||||||
|
|
||||||
def add_csrf_param(params):
|
|
||||||
""" Returns a params dict with the CSRF parameter added. """
|
|
||||||
params = params or {}
|
|
||||||
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
def conduct_api_call(client, resource, method, params, body=None, expected_code=200):
|
def conduct_api_call(client, resource, method, params, body=None, expected_code=200):
|
||||||
""" Conducts an API call to the given resource via the given client, and ensures its returned
|
""" Conducts an API call to the given resource via the given client, and ensures its returned
|
||||||
status matches the code given.
|
status matches the code given.
|
||||||
|
|
||||||
Returns the response.
|
Returns the response.
|
||||||
"""
|
"""
|
||||||
params = add_csrf_param(params)
|
return conduct_call(client, resource, api.url_for, method, params, body, expected_code)
|
||||||
|
|
||||||
final_url = api.url_for(resource, **params)
|
|
||||||
|
|
||||||
headers = {}
|
|
||||||
headers.update({"Content-Type": "application/json"})
|
|
||||||
|
|
||||||
if body is not None:
|
|
||||||
body = json.dumps(body)
|
|
||||||
|
|
||||||
rv = client.open(final_url, method=method, data=body, headers=headers)
|
|
||||||
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
|
|
||||||
rv.data)
|
|
||||||
assert rv.status_code == expected_code, msg
|
|
||||||
return rv
|
|
||||||
|
|
|
@ -16,7 +16,8 @@ from endpoints.api.trigger import (BuildTriggerList, BuildTrigger, BuildTriggerS
|
||||||
BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger,
|
BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger,
|
||||||
TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources,
|
TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources,
|
||||||
BuildTriggerSourceNamespaces)
|
BuildTriggerSourceNamespaces)
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
BUILD_ARGS = {'build_uuid': '1234'}
|
BUILD_ARGS = {'build_uuid': '1234'}
|
||||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
||||||
|
|
||||||
from data import model
|
from data import model
|
||||||
from endpoints.api import api
|
from endpoints.api import api
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
from endpoints.api.organization import Organization
|
from endpoints.api.organization import Organization
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
@pytest.mark.parametrize('expiration, expected_code', [
|
@pytest.mark.parametrize('expiration, expected_code', [
|
||||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
||||||
|
|
||||||
from mock import patch, ANY, MagicMock
|
from mock import patch, ANY, MagicMock
|
||||||
|
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
from endpoints.api.repository import RepositoryTrust, Repository
|
from endpoints.api.repository import RepositoryTrust, Repository
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
from features import FeatureNameValue
|
from features import FeatureNameValue
|
||||||
|
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
|
@ -4,7 +4,8 @@ from playhouse.test_utils import assert_query_count
|
||||||
|
|
||||||
from data.model import _basequery
|
from data.model import _basequery
|
||||||
from endpoints.api.search import ConductRepositorySearch, ConductSearch
|
from endpoints.api.search import ConductRepositorySearch, ConductSearch
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
@pytest.mark.parametrize('query, expected_query_count', [
|
@pytest.mark.parametrize('query, expected_query_count', [
|
||||||
|
|
|
@ -4,12 +4,13 @@ from flask_principal import AnonymousIdentity
|
||||||
from endpoints.api import api
|
from endpoints.api import api
|
||||||
from endpoints.api.repositorynotification import RepositoryNotification
|
from endpoints.api.repositorynotification import RepositoryNotification
|
||||||
from endpoints.api.team import OrganizationTeamSyncing
|
from endpoints.api.team import OrganizationTeamSyncing
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
from endpoints.api.repository import RepositoryTrust
|
from endpoints.api.repository import RepositoryTrust
|
||||||
from endpoints.api.signing import RepositorySignatures
|
from endpoints.api.signing import RepositorySignatures
|
||||||
from endpoints.api.search import ConductRepositorySearch
|
from endpoints.api.search import ConductRepositorySearch
|
||||||
from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource
|
from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource
|
||||||
from endpoints.api.superuser import SuperUserRepositoryBuildStatus
|
from endpoints.api.superuser import SuperUserRepositoryBuildStatus
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
|
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,9 @@ import pytest
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from mock import patch
|
from mock import patch
|
||||||
|
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
from endpoints.api.signing import RepositorySignatures
|
from endpoints.api.signing import RepositorySignatures
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
|
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,10 @@ import pytest
|
||||||
|
|
||||||
from mock import patch, Mock
|
from mock import patch, Mock
|
||||||
|
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
from endpoints.api.tag import RepositoryTag, RestoreTag
|
from endpoints.api.tag import RepositoryTag, RestoreTag
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
|
|
||||||
from features import FeatureNameValue
|
from features import FeatureNameValue
|
||||||
|
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
|
@ -4,9 +4,11 @@ from mock import patch
|
||||||
|
|
||||||
from data import model
|
from data import model
|
||||||
from endpoints.api import api
|
from endpoints.api import api
|
||||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
from endpoints.api.test.shared import conduct_api_call
|
||||||
from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList
|
from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList
|
||||||
from endpoints.api.organization import Organization
|
from endpoints.api.organization import Organization
|
||||||
|
from endpoints.test.shared import client_with_identity
|
||||||
|
|
||||||
from test.test_ldap import mock_ldap
|
from test.test_ldap import mock_ldap
|
||||||
|
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
|
@ -5,7 +5,7 @@ from flask import url_for
|
||||||
|
|
||||||
from data import model
|
from data import model
|
||||||
from endpoints.appr.registry import appr_bp, blobs
|
from endpoints.appr.registry import appr_bp, blobs
|
||||||
from endpoints.api.test.shared import client_with_identity
|
from endpoints.test.shared import client_with_identity
|
||||||
from test.fixtures import *
|
from test.fixtures import *
|
||||||
|
|
||||||
BLOB_ARGS = {'digest': 'abcd1235'}
|
BLOB_ARGS = {'digest': 'abcd1235'}
|
||||||
|
|
0
endpoints/test/__init__.py
Normal file
0
endpoints/test/__init__.py
Normal file
55
endpoints/test/shared.py
Normal file
55
endpoints/test/shared.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from data import model
|
||||||
|
|
||||||
|
CSRF_TOKEN_KEY = '_csrf_token'
|
||||||
|
CSRF_TOKEN = '123csrfforme'
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def client_with_identity(auth_username, client):
|
||||||
|
with client.session_transaction() as sess:
|
||||||
|
if auth_username and auth_username is not None:
|
||||||
|
loaded = model.user.get_user(auth_username)
|
||||||
|
sess['user_id'] = loaded.uuid
|
||||||
|
sess['login_time'] = datetime.datetime.now()
|
||||||
|
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||||
|
else:
|
||||||
|
sess['user_id'] = 'anonymous'
|
||||||
|
|
||||||
|
yield client
|
||||||
|
|
||||||
|
with client.session_transaction() as sess:
|
||||||
|
sess['user_id'] = None
|
||||||
|
sess['login_time'] = None
|
||||||
|
sess[CSRF_TOKEN_KEY] = None
|
||||||
|
|
||||||
|
|
||||||
|
def add_csrf_param(params):
|
||||||
|
""" Returns a params dict with the CSRF parameter added. """
|
||||||
|
params = params or {}
|
||||||
|
|
||||||
|
if not CSRF_TOKEN_KEY in params:
|
||||||
|
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200, headers=None):
|
||||||
|
""" Conducts a call to a Flask endpoint. """
|
||||||
|
params = add_csrf_param(params)
|
||||||
|
|
||||||
|
final_url = url_for(resource, **params)
|
||||||
|
|
||||||
|
headers = headers or {}
|
||||||
|
headers.update({"Content-Type": "application/json"})
|
||||||
|
|
||||||
|
if body is not None:
|
||||||
|
body = json.dumps(body)
|
||||||
|
|
||||||
|
rv = client.open(final_url, method=method, data=body, headers=headers)
|
||||||
|
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
|
||||||
|
rv.data)
|
||||||
|
assert rv.status_code == expected_code, msg
|
||||||
|
return rv
|
Reference in a new issue