Switch csrf token check to use compare_digest
to prevent timing attacks
Also adds some additional tests for CSRF tokens
This commit is contained in:
parent
dbdcb802b1
commit
1302fd2fbd
2 changed files with 36 additions and 12 deletions
|
@ -1,9 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
|
import hmac
|
||||||
|
|
||||||
from flask import session, request
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from flask import session, request
|
||||||
|
|
||||||
from app import app
|
from app import app
|
||||||
from auth.auth_context import get_validated_oauth_token
|
from auth.auth_context import get_validated_oauth_token
|
||||||
|
@ -30,9 +31,10 @@ def verify_csrf(session_token_name=_QUAY_CSRF_TOKEN_NAME,
|
||||||
""" Verifies that the CSRF token with the given name is found in the session and
|
""" Verifies that the CSRF token with the given name is found in the session and
|
||||||
that the matching token is found in the request args or values.
|
that the matching token is found in the request args or values.
|
||||||
"""
|
"""
|
||||||
token = session.get(session_token_name, None)
|
token = str(session.get(session_token_name, ''))
|
||||||
found_token = request.values.get(request_token_name, None)
|
found_token = str(request.values.get(request_token_name, ''))
|
||||||
if not token or token != found_token:
|
|
||||||
|
if not token or not found_token or not hmac.compare_digest(token, found_token):
|
||||||
msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s'
|
msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s'
|
||||||
logger.error(msg, session_token_name, token, request_token_name, found_token)
|
logger.error(msg, session_token_name, token, request_token_name, found_token)
|
||||||
abort(403, message='CSRF token was invalid or missing.')
|
abort(403, message='CSRF token was invalid or missing.')
|
||||||
|
|
|
@ -124,9 +124,11 @@ class ApiTestCase(unittest.TestCase):
|
||||||
query[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
query[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||||
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
|
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
|
||||||
|
|
||||||
def url_for(self, resource_name, params={}):
|
def url_for(self, resource_name, params=None, skip_csrf=False):
|
||||||
|
params = params or {}
|
||||||
url = api.url_for(resource_name, **params)
|
url = api.url_for(resource_name, **params)
|
||||||
url = ApiTestCase._add_csrf(url)
|
if not skip_csrf:
|
||||||
|
url = ApiTestCase._add_csrf(url)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -211,8 +213,8 @@ class ApiTestCase(unittest.TestCase):
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
def putJsonResponse(self, resource_name, params={}, data={},
|
def putJsonResponse(self, resource_name, params={}, data={},
|
||||||
expected_code=200):
|
expected_code=200, skip_csrf=False):
|
||||||
rv = self.app.put(self.url_for(resource_name, params),
|
rv = self.app.put(self.url_for(resource_name, params, skip_csrf),
|
||||||
data=py_json.dumps(data),
|
data=py_json.dumps(data),
|
||||||
headers={"Content-Type": "application/json"})
|
headers={"Content-Type": "application/json"})
|
||||||
|
|
||||||
|
@ -246,15 +248,35 @@ class TestCSRFFailure(ApiTestCase):
|
||||||
self.login(READ_ACCESS_USER)
|
self.login(READ_ACCESS_USER)
|
||||||
|
|
||||||
# Make sure a simple post call succeeds.
|
# Make sure a simple post call succeeds.
|
||||||
self.putJsonResponse(User,
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'))
|
||||||
data=dict(password='newpasswordiscool'))
|
|
||||||
|
|
||||||
# Change the session's CSRF token.
|
# Change the session's CSRF token.
|
||||||
self.setCsrfToken('someinvalidtoken')
|
self.setCsrfToken('someinvalidtoken')
|
||||||
|
|
||||||
# Verify that the call now fails.
|
# Verify that the call now fails.
|
||||||
self.putJsonResponse(User,
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403)
|
||||||
data=dict(password='newpasswordiscool'),
|
|
||||||
|
def test_csrf_failure_empty_token(self):
|
||||||
|
self.login(READ_ACCESS_USER)
|
||||||
|
|
||||||
|
# Change the session's CSRF token to be empty.
|
||||||
|
self.setCsrfToken('')
|
||||||
|
|
||||||
|
# Verify that the call now fails.
|
||||||
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403)
|
||||||
|
|
||||||
|
def test_csrf_failure_missing_token(self):
|
||||||
|
self.login(READ_ACCESS_USER)
|
||||||
|
|
||||||
|
# Make sure a simple post call without a token at all fails.
|
||||||
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True,
|
||||||
|
expected_code=403)
|
||||||
|
|
||||||
|
# Change the session's CSRF token to be empty.
|
||||||
|
self.setCsrfToken('')
|
||||||
|
|
||||||
|
# Verify that the call still fails.
|
||||||
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True,
|
||||||
expected_code=403)
|
expected_code=403)
|
||||||
|
|
||||||
|
|
||||||
|
|
Reference in a new issue