Fix the test_api_security tests for csrf.

This commit is contained in:
jakedt 2014-03-25 14:53:27 -04:00
parent 219fbd6950
commit 26a57d0c21

View file

@ -1,6 +1,9 @@
import unittest
import json
from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qs
from app import app
from initdb import setup_database_for_testing, finished_database_for_testing
from endpoints.api import api_bp, api
@ -37,36 +40,53 @@ from endpoints.api.permission import (RepositoryUserPermission, RepositoryTeamPe
app.register_blueprint(api_bp, url_prefix='/api')
CSRF_TOKEN_KEY = '_csrf_token'
CSRF_TOKEN = '123csrfforme'
class ApiTestCase(unittest.TestCase):
@staticmethod
def _add_csrf(without_csrf):
parts = urlparse(without_csrf)
query = parse_qs(parts[4])
query[CSRF_TOKEN_KEY] = CSRF_TOKEN
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
def _set_url(self, resource, **url_params):
with app.test_request_context():
self.url = api.url_for(resource, **url_params)
def _run_test(self, method, expected_status, auth_username=None, request_body=None):
with app.test_client() as client:
if auth_username:
# Temporarily remove the teardown functions
teardown_funcs = []
if None in app.teardown_request_funcs:
teardown_funcs = app.teardown_request_funcs[None]
app.teardown_request_funcs[None] = []
# Temporarily remove the teardown functions
teardown_funcs = []
if None in app.teardown_request_funcs:
teardown_funcs = app.teardown_request_funcs[None]
app.teardown_request_funcs[None] = []
with client.session_transaction() as sess:
with client.session_transaction() as sess:
if auth_username:
sess['user_id'] = auth_username
sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
# Restore the teardown functions
app.teardown_request_funcs[None] = teardown_funcs
# Restore the teardown functions
app.teardown_request_funcs[None] = teardown_funcs
open_kwargs = {
'method': method
}
final_url = self.url
if method != 'GET' and method != 'HEAD':
final_url = self._add_csrf(self.url)
open_kwargs.update({
'data': json.dumps(request_body),
'content_type': 'application/json',
})
rv = client.open(self.url, **open_kwargs)
msg = '%s %s: %s expected: %s' % (method, self.url, rv.status_code, expected_status)
rv = client.open(final_url, **open_kwargs)
msg = '%s %s: %s expected: %s' % (method, final_url, rv.status_code, expected_status)
self.assertEqual(rv.status_code, expected_status, msg)
def setUp(self):