diff --git a/endpoints/csrf.py b/endpoints/csrf.py index 28fee6c74..b2dbfcff1 100644 --- a/endpoints/csrf.py +++ b/endpoints/csrf.py @@ -1,9 +1,10 @@ import logging import os import base64 +import hmac -from flask import session, request from functools import wraps +from flask import session, request from app import app from auth.auth_context import get_validated_oauth_token @@ -30,9 +31,10 @@ def verify_csrf(session_token_name=_QUAY_CSRF_TOKEN_NAME, """ Verifies that the CSRF token with the given name is found in the session and that the matching token is found in the request args or values. """ - token = session.get(session_token_name, None) - found_token = request.values.get(request_token_name, None) - if not token or token != found_token: + token = str(session.get(session_token_name, '')) + found_token = str(request.values.get(request_token_name, '')) + + if not token or not found_token or not hmac.compare_digest(token, found_token): msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s' logger.error(msg, session_token_name, token, request_token_name, found_token) abort(403, message='CSRF token was invalid or missing.') diff --git a/test/test_api_usage.py b/test/test_api_usage.py index 1bfb8f9dc..f89003972 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -124,9 +124,11 @@ class ApiTestCase(unittest.TestCase): query[CSRF_TOKEN_KEY] = CSRF_TOKEN return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:])) - def url_for(self, resource_name, params={}): + def url_for(self, resource_name, params=None, skip_csrf=False): + params = params or {} url = api.url_for(resource_name, **params) - url = ApiTestCase._add_csrf(url) + if not skip_csrf: + url = ApiTestCase._add_csrf(url) return url def setUp(self): @@ -211,8 +213,8 @@ class ApiTestCase(unittest.TestCase): return parsed def putJsonResponse(self, resource_name, params={}, data={}, - expected_code=200): - rv = self.app.put(self.url_for(resource_name, params), + expected_code=200, skip_csrf=False): + rv = self.app.put(self.url_for(resource_name, params, skip_csrf), data=py_json.dumps(data), headers={"Content-Type": "application/json"}) @@ -246,15 +248,35 @@ class TestCSRFFailure(ApiTestCase): self.login(READ_ACCESS_USER) # Make sure a simple post call succeeds. - self.putJsonResponse(User, - data=dict(password='newpasswordiscool')) + self.putJsonResponse(User, data=dict(password='newpasswordiscool')) # Change the session's CSRF token. self.setCsrfToken('someinvalidtoken') # Verify that the call now fails. - self.putJsonResponse(User, - data=dict(password='newpasswordiscool'), + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403) + + def test_csrf_failure_empty_token(self): + self.login(READ_ACCESS_USER) + + # Change the session's CSRF token to be empty. + self.setCsrfToken('') + + # Verify that the call now fails. + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403) + + def test_csrf_failure_missing_token(self): + self.login(READ_ACCESS_USER) + + # Make sure a simple post call without a token at all fails. + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True, + expected_code=403) + + # Change the session's CSRF token to be empty. + self.setCsrfToken('') + + # Verify that the call still fails. + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True, expected_code=403)