From 2f018046ecf726689158510bf36263f840dd7c8a Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Wed, 28 Jun 2017 11:38:36 +0300 Subject: [PATCH] Move conduct_call into a common test lib for all endpoints --- endpoints/api/test/shared.py | 52 +----------------- endpoints/api/test/test_disallow_for_apps.py | 3 +- endpoints/api/test/test_organization.py | 3 +- endpoints/api/test/test_repository.py | 7 +-- endpoints/api/test/test_search.py | 3 +- endpoints/api/test/test_security.py | 3 +- endpoints/api/test/test_signing.py | 13 ++--- endpoints/api/test/test_tag.py | 4 +- endpoints/api/test/test_team.py | 4 +- endpoints/appr/test/test_api_security.py | 2 +- endpoints/test/__init__.py | 0 endpoints/test/shared.py | 55 ++++++++++++++++++++ 12 files changed, 83 insertions(+), 66 deletions(-) create mode 100644 endpoints/test/__init__.py create mode 100644 endpoints/test/shared.py diff --git a/endpoints/api/test/shared.py b/endpoints/api/test/shared.py index 3d1f0cffa..1d35cdbc5 100644 --- a/endpoints/api/test/shared.py +++ b/endpoints/api/test/shared.py @@ -1,58 +1,10 @@ -import datetime -import json - -from contextlib import contextmanager -from data import model +from endpoints.test.shared import conduct_call from endpoints.api import api -CSRF_TOKEN_KEY = '_csrf_token' -CSRF_TOKEN = '123csrfforme' - - -@contextmanager -def client_with_identity(auth_username, client): - with client.session_transaction() as sess: - if auth_username and auth_username is not None: - loaded = model.user.get_user(auth_username) - sess['user_id'] = loaded.uuid - sess['login_time'] = datetime.datetime.now() - sess[CSRF_TOKEN_KEY] = CSRF_TOKEN - else: - sess['user_id'] = 'anonymous' - - yield client - - with client.session_transaction() as sess: - sess['user_id'] = None - sess['login_time'] = None - sess[CSRF_TOKEN_KEY] = None - - -def add_csrf_param(params): - """ Returns a params dict with the CSRF parameter added. """ - params = params or {} - params[CSRF_TOKEN_KEY] = CSRF_TOKEN - return params - - def conduct_api_call(client, resource, method, params, body=None, expected_code=200): """ Conducts an API call to the given resource via the given client, and ensures its returned status matches the code given. Returns the response. """ - params = add_csrf_param(params) - - final_url = api.url_for(resource, **params) - - headers = {} - headers.update({"Content-Type": "application/json"}) - - if body is not None: - body = json.dumps(body) - - rv = client.open(final_url, method=method, data=body, headers=headers) - msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code, - rv.data) - assert rv.status_code == expected_code, msg - return rv + return conduct_call(client, resource, api.url_for, method, params, body, expected_code) diff --git a/endpoints/api/test/test_disallow_for_apps.py b/endpoints/api/test/test_disallow_for_apps.py index 6de35c03b..b9112c291 100644 --- a/endpoints/api/test/test_disallow_for_apps.py +++ b/endpoints/api/test/test_disallow_for_apps.py @@ -16,7 +16,8 @@ from endpoints.api.trigger import (BuildTriggerList, BuildTrigger, BuildTriggerS BuildTriggerActivate, BuildTriggerAnalyze, ActivateBuildTrigger, TriggerBuildList, BuildTriggerFieldValues, BuildTriggerSources, BuildTriggerSourceNamespaces) -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call +from endpoints.test.shared import client_with_identity from test.fixtures import * BUILD_ARGS = {'build_uuid': '1234'} diff --git a/endpoints/api/test/test_organization.py b/endpoints/api/test/test_organization.py index 65b9a85d4..9a6525113 100644 --- a/endpoints/api/test/test_organization.py +++ b/endpoints/api/test/test_organization.py @@ -2,8 +2,9 @@ import pytest from data import model from endpoints.api import api -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call from endpoints.api.organization import Organization +from endpoints.test.shared import client_with_identity from test.fixtures import * @pytest.mark.parametrize('expiration, expected_code', [ diff --git a/endpoints/api/test/test_repository.py b/endpoints/api/test/test_repository.py index d110f5760..999beb00d 100644 --- a/endpoints/api/test/test_repository.py +++ b/endpoints/api/test/test_repository.py @@ -2,8 +2,9 @@ import pytest from mock import patch, ANY, MagicMock -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call from endpoints.api.repository import RepositoryTrust, Repository +from endpoints.test.shared import client_with_identity from features import FeatureNameValue from test.fixtures import * @@ -52,8 +53,8 @@ def test_signing_disabled(client): params = {'repository': 'devtable/simple'} response = conduct_api_call(cl, Repository, 'GET', params).json assert not response['trust_enabled'] - - + + def test_sni_support(): import ssl assert ssl.HAS_SNI diff --git a/endpoints/api/test/test_search.py b/endpoints/api/test/test_search.py index 4efba0841..1cca8d548 100644 --- a/endpoints/api/test/test_search.py +++ b/endpoints/api/test/test_search.py @@ -4,7 +4,8 @@ from playhouse.test_utils import assert_query_count from data.model import _basequery from endpoints.api.search import ConductRepositorySearch, ConductSearch -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call +from endpoints.test.shared import client_with_identity from test.fixtures import * @pytest.mark.parametrize('query, expected_query_count', [ diff --git a/endpoints/api/test/test_security.py b/endpoints/api/test/test_security.py index 40140b6fa..68039aed7 100644 --- a/endpoints/api/test/test_security.py +++ b/endpoints/api/test/test_security.py @@ -4,12 +4,13 @@ from flask_principal import AnonymousIdentity from endpoints.api import api from endpoints.api.repositorynotification import RepositoryNotification from endpoints.api.team import OrganizationTeamSyncing -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call from endpoints.api.repository import RepositoryTrust from endpoints.api.signing import RepositorySignatures from endpoints.api.search import ConductRepositorySearch from endpoints.api.superuser import SuperUserRepositoryBuildLogs, SuperUserRepositoryBuildResource from endpoints.api.superuser import SuperUserRepositoryBuildStatus +from endpoints.test.shared import client_with_identity from test.fixtures import * diff --git a/endpoints/api/test/test_signing.py b/endpoints/api/test/test_signing.py index 31f37d632..e941cee56 100644 --- a/endpoints/api/test/test_signing.py +++ b/endpoints/api/test/test_signing.py @@ -3,8 +3,9 @@ import pytest from collections import Counter from mock import patch -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call from endpoints.api.signing import RepositorySignatures +from endpoints.test.shared import client_with_identity from test.fixtures import * @@ -14,21 +15,21 @@ VALID_TARGETS_MAP = { "latest": { "hashes": { "sha256": "2Q8GLEgX62VBWeL76axFuDj/Z1dd6Zhx0ZDM6kNwPkQ=" - }, + }, "length": 2111 } - }, + }, "expiration": "2020-05-22T10:26:46.618176424-04:00" - }, + }, "targets": { "targets": { "latest": { "hashes": { "sha256": "2Q8GLEgX62VBWeL76axFuDj/Z1dd6Zhx0ZDM6kNwPkQ=" - }, + }, "length": 2111 } - }, + }, "expiration": "2020-05-22T10:26:01.953414888-04:00"} } diff --git a/endpoints/api/test/test_tag.py b/endpoints/api/test/test_tag.py index 0c80ef4ee..a94261fc4 100644 --- a/endpoints/api/test/test_tag.py +++ b/endpoints/api/test/test_tag.py @@ -2,8 +2,10 @@ import pytest from mock import patch, Mock -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call from endpoints.api.tag import RepositoryTag, RestoreTag +from endpoints.test.shared import client_with_identity + from features import FeatureNameValue from test.fixtures import * diff --git a/endpoints/api/test/test_team.py b/endpoints/api/test/test_team.py index c40f8f199..9a17a36e4 100644 --- a/endpoints/api/test/test_team.py +++ b/endpoints/api/test/test_team.py @@ -4,9 +4,11 @@ from mock import patch from data import model from endpoints.api import api -from endpoints.api.test.shared import client_with_identity, conduct_api_call +from endpoints.api.test.shared import conduct_api_call from endpoints.api.team import OrganizationTeamSyncing, TeamMemberList from endpoints.api.organization import Organization +from endpoints.test.shared import client_with_identity + from test.test_ldap import mock_ldap from test.fixtures import * diff --git a/endpoints/appr/test/test_api_security.py b/endpoints/appr/test/test_api_security.py index e37b2f092..c3e52b30c 100644 --- a/endpoints/appr/test/test_api_security.py +++ b/endpoints/appr/test/test_api_security.py @@ -5,7 +5,7 @@ from flask import url_for from data import model from endpoints.appr.registry import appr_bp, blobs -from endpoints.api.test.shared import client_with_identity +from endpoints.test.shared import client_with_identity from test.fixtures import * BLOB_ARGS = {'digest': 'abcd1235'} diff --git a/endpoints/test/__init__.py b/endpoints/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/endpoints/test/shared.py b/endpoints/test/shared.py new file mode 100644 index 000000000..baf7de18f --- /dev/null +++ b/endpoints/test/shared.py @@ -0,0 +1,55 @@ +import datetime +import json + +from contextlib import contextmanager +from data import model + +CSRF_TOKEN_KEY = '_csrf_token' +CSRF_TOKEN = '123csrfforme' + +@contextmanager +def client_with_identity(auth_username, client): + with client.session_transaction() as sess: + if auth_username and auth_username is not None: + loaded = model.user.get_user(auth_username) + sess['user_id'] = loaded.uuid + sess['login_time'] = datetime.datetime.now() + sess[CSRF_TOKEN_KEY] = CSRF_TOKEN + else: + sess['user_id'] = 'anonymous' + + yield client + + with client.session_transaction() as sess: + sess['user_id'] = None + sess['login_time'] = None + sess[CSRF_TOKEN_KEY] = None + + +def add_csrf_param(params): + """ Returns a params dict with the CSRF parameter added. """ + params = params or {} + + if not CSRF_TOKEN_KEY in params: + params[CSRF_TOKEN_KEY] = CSRF_TOKEN + + return params + + +def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200, headers=None): + """ Conducts a call to a Flask endpoint. """ + params = add_csrf_param(params) + + final_url = url_for(resource, **params) + + headers = headers or {} + headers.update({"Content-Type": "application/json"}) + + if body is not None: + body = json.dumps(body) + + rv = client.open(final_url, method=method, data=body, headers=headers) + msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code, + rv.data) + assert rv.status_code == expected_code, msg + return rv