Move conduct_call into a common test lib for all endpoints
This commit is contained in:
parent
91d2cb1ec1
commit
2f018046ec
12 changed files with 83 additions and 66 deletions
|
@ -1,58 +1,10 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from contextlib import contextmanager
|
||||
from data import model
|
||||
from endpoints.test.shared import conduct_call
|
||||
from endpoints.api import api
|
||||
|
||||
CSRF_TOKEN_KEY = '_csrf_token'
|
||||
CSRF_TOKEN = '123csrfforme'
|
||||
|
||||
|
||||
@contextmanager
|
||||
def client_with_identity(auth_username, client):
|
||||
with client.session_transaction() as sess:
|
||||
if auth_username and auth_username is not None:
|
||||
loaded = model.user.get_user(auth_username)
|
||||
sess['user_id'] = loaded.uuid
|
||||
sess['login_time'] = datetime.datetime.now()
|
||||
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||
else:
|
||||
sess['user_id'] = 'anonymous'
|
||||
|
||||
yield client
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
sess['user_id'] = None
|
||||
sess['login_time'] = None
|
||||
sess[CSRF_TOKEN_KEY] = None
|
||||
|
||||
|
||||
def add_csrf_param(params):
|
||||
""" Returns a params dict with the CSRF parameter added. """
|
||||
params = params or {}
|
||||
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||
return params
|
||||
|
||||
|
||||
def conduct_api_call(client, resource, method, params, body=None, expected_code=200):
|
||||
""" Conducts an API call to the given resource via the given client, and ensures its returned
|
||||
status matches the code given.
|
||||
|
||||
Returns the response.
|
||||
"""
|
||||
params = add_csrf_param(params)
|
||||
|
||||
final_url = api.url_for(resource, **params)
|
||||
|
||||
headers = {}
|
||||
headers.update({"Content-Type": "application/json"})
|
||||
|
||||
if body is not None:
|
||||
body = json.dumps(body)
|
||||
|
||||
rv = client.open(final_url, method=method, data=body, headers=headers)
|
||||
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
|
||||
rv.data)
|
||||
assert rv.status_code == expected_code, msg
|
||||
return rv
|
||||
return conduct_call(client, resource, api.url_for, method, params, body, expected_code)
|
||||
|
|
|
@ -16,7 +16,8 @@ from endpoints.api.trigger import (BuildTriggerList, BuildTrigger, BuildTriggerS
|
|||
BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger,
|
||||
TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources,
|
||||
BuildTriggerSourceNamespaces)
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.test.shared import client_with_identity
|
||||
from test.fixtures import *
|
||||
|
||||
BUILD_ARGS = {'build_uuid': '1234'}
|
||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
|||
|
||||
from data import model
|
||||
from endpoints.api import api
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.api.organization import Organization
|
||||
from endpoints.test.shared import client_with_identity
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('expiration, expected_code', [
|
||||
|
|
|
@ -2,8 +2,9 @@ import pytest
|
|||
|
||||
from mock import patch, ANY, MagicMock
|
||||
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.api.repository import RepositoryTrust, Repository
|
||||
from endpoints.test.shared import client_with_identity
|
||||
from features import FeatureNameValue
|
||||
|
||||
from test.fixtures import *
|
||||
|
|
|
@ -4,7 +4,8 @@ from playhouse.test_utils import assert_query_count
|
|||
|
||||
from data.model import _basequery
|
||||
from endpoints.api.search import ConductRepositorySearch, ConductSearch
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.test.shared import client_with_identity
|
||||
from test.fixtures import *
|
||||
|
||||
@pytest.mark.parametrize('query, expected_query_count', [
|
||||
|
|
|
@ -4,12 +4,13 @@ from flask_principal import AnonymousIdentity
|
|||
from endpoints.api import api
|
||||
from endpoints.api.repositorynotification import RepositoryNotification
|
||||
from endpoints.api.team import OrganizationTeamSyncing
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.api.repository import RepositoryTrust
|
||||
from endpoints.api.signing import RepositorySignatures
|
||||
from endpoints.api.search import ConductRepositorySearch
|
||||
from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource
|
||||
from endpoints.api.superuser import SuperUserRepositoryBuildStatus
|
||||
from endpoints.test.shared import client_with_identity
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
|
|
|
@ -3,8 +3,9 @@ import pytest
|
|||
from collections import Counter
|
||||
from mock import patch
|
||||
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.api.signing import RepositorySignatures
|
||||
from endpoints.test.shared import client_with_identity
|
||||
|
||||
from test.fixtures import *
|
||||
|
||||
|
|
|
@ -2,8 +2,10 @@ import pytest
|
|||
|
||||
from mock import patch, Mock
|
||||
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.api.tag import RepositoryTag, RestoreTag
|
||||
from endpoints.test.shared import client_with_identity
|
||||
|
||||
from features import FeatureNameValue
|
||||
|
||||
from test.fixtures import *
|
||||
|
|
|
@ -4,9 +4,11 @@ from mock import patch
|
|||
|
||||
from data import model
|
||||
from endpoints.api import api
|
||||
from endpoints.api.test.shared import client_with_identity, conduct_api_call
|
||||
from endpoints.api.test.shared import conduct_api_call
|
||||
from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList
|
||||
from endpoints.api.organization import Organization
|
||||
from endpoints.test.shared import client_with_identity
|
||||
|
||||
from test.test_ldap import mock_ldap
|
||||
|
||||
from test.fixtures import *
|
||||
|
|
|
@ -5,7 +5,7 @@ from flask import url_for
|
|||
|
||||
from data import model
|
||||
from endpoints.appr.registry import appr_bp, blobs
|
||||
from endpoints.api.test.shared import client_with_identity
|
||||
from endpoints.test.shared import client_with_identity
|
||||
from test.fixtures import *
|
||||
|
||||
BLOB_ARGS = {'digest': 'abcd1235'}
|
||||
|
|
0
endpoints/test/__init__.py
Normal file
0
endpoints/test/__init__.py
Normal file
55
endpoints/test/shared.py
Normal file
55
endpoints/test/shared.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from contextlib import contextmanager
|
||||
from data import model
|
||||
|
||||
CSRF_TOKEN_KEY = '_csrf_token'
|
||||
CSRF_TOKEN = '123csrfforme'
|
||||
|
||||
@contextmanager
|
||||
def client_with_identity(auth_username, client):
|
||||
with client.session_transaction() as sess:
|
||||
if auth_username and auth_username is not None:
|
||||
loaded = model.user.get_user(auth_username)
|
||||
sess['user_id'] = loaded.uuid
|
||||
sess['login_time'] = datetime.datetime.now()
|
||||
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||
else:
|
||||
sess['user_id'] = 'anonymous'
|
||||
|
||||
yield client
|
||||
|
||||
with client.session_transaction() as sess:
|
||||
sess['user_id'] = None
|
||||
sess['login_time'] = None
|
||||
sess[CSRF_TOKEN_KEY] = None
|
||||
|
||||
|
||||
def add_csrf_param(params):
|
||||
""" Returns a params dict with the CSRF parameter added. """
|
||||
params = params or {}
|
||||
|
||||
if not CSRF_TOKEN_KEY in params:
|
||||
params[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200, headers=None):
|
||||
""" Conducts a call to a Flask endpoint. """
|
||||
params = add_csrf_param(params)
|
||||
|
||||
final_url = url_for(resource, **params)
|
||||
|
||||
headers = headers or {}
|
||||
headers.update({"Content-Type": "application/json"})
|
||||
|
||||
if body is not None:
|
||||
body = json.dumps(body)
|
||||
|
||||
rv = client.open(final_url, method=method, data=body, headers=headers)
|
||||
msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
|
||||
rv.data)
|
||||
assert rv.status_code == expected_code, msg
|
||||
return rv
|
Reference in a new issue