68 lines
		
	
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			68 lines
		
	
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import datetime
 | |
| import json
 | |
| import base64
 | |
| 
 | |
| from contextlib import contextmanager
 | |
| from data import model
 | |
| 
 | |
| from flask import g
 | |
| from flask_principal import Identity
 | |
| 
 | |
# Key under which the CSRF token is stored in the session (client_with_identity) and sent as a
# request parameter (add_csrf_param).
CSRF_TOKEN_KEY = '_csrf_token'

# Fixed CSRF token value written into the session and attached to request params.
CSRF_TOKEN = '123csrfforme'
 | |
| 
 | |
@contextmanager
def client_with_identity(auth_username, client):
  """ Context manager that sets up `client`'s session as `auth_username` for the block.

  On entry the session is populated: when `auth_username` is truthy, the user is loaded with
  `model.user.get_user` and its `uuid`, the current (naive local) login time and the fixed
  CSRF token are stored; otherwise the session's `user_id` is set to 'anonymous'. The same
  `client` is yielded. On exit -- including when the body raises -- `user_id`, `login_time`
  and the CSRF token entries are reset to None.

  Args:
    auth_username: username to log in as, or a falsy value (e.g. None) for anonymous.
    client: test client exposing `session_transaction()`.

  Yields:
    The `client` passed in.
  """
  with client.session_transaction() as sess:
    if auth_username:
      loaded = model.user.get_user(auth_username)
      # NOTE(review): assumes the user exists; a missing user surfaces as AttributeError here.
      sess['user_id'] = loaded.uuid
      sess['login_time'] = datetime.datetime.now()
      sess[CSRF_TOKEN_KEY] = CSRF_TOKEN
    else:
      sess['user_id'] = 'anonymous'

  try:
    yield client
  finally:
    # Reset the session even if the body raised, so a failing test does not leave this
    # client logged in for its next use.
    with client.session_transaction() as sess:
      sess['user_id'] = None
      sess['login_time'] = None
      sess[CSRF_TOKEN_KEY] = None
 | |
| 
 | |
| 
 | |
def add_csrf_param(params):
  """ Returns a params dict with the CSRF parameter added.

  The input is copied rather than modified in place, so a dict the caller reuses is left
  untouched. A value already present under `CSRF_TOKEN_KEY` is kept as-is.

  Args:
    params: mapping of request parameters, or None / empty for no parameters.

  Returns:
    A new dict with the contents of `params`, plus `CSRF_TOKEN_KEY` -> `CSRF_TOKEN` when
    that key was absent.
  """
  result = dict(params or {})
  result.setdefault(CSRF_TOKEN_KEY, CSRF_TOKEN)
  return result
 | |
| 
 | |
| 
 | |
def gen_basic_auth(username, password):
  """ Generates a basic auth header value for `username` and `password`.

  The credentials are joined as "username:password", UTF-8 encoded and base64 encoded
  (RFC 7617). The result is returned as text, e.g. gen_basic_auth('user', 'pass') returns
  'Basic dXNlcjpwYXNz'.
  """
  credentials = '%s:%s' % (username, password)
  # b64encode requires bytes on Python 3. Only encode text, so that a byte string produced by
  # '%s' formatting (Python 2 `str`) is passed through unchanged.
  if not isinstance(credentials, bytes):
    credentials = credentials.encode('utf-8')
  return 'Basic ' + base64.b64encode(credentials).decode('ascii')
 | |
| 
 | |
| 
 | |
def conduct_call(client, resource, url_for, method, params, body=None, expected_code=200,
                 headers=None):
  """ Conducts a call to a Flask endpoint and asserts on the response status.

  Args:
    client: test client used to issue the request via `client.open`.
    resource: endpoint name passed to `url_for`.
    url_for: callable building the request URL from `resource` and the params.
    method: HTTP method string passed to `client.open`.
    params: URL parameters (may be None); the CSRF parameter is added.
    body: optional JSON-serializable request body; sent as a JSON string when given.
    expected_code: HTTP status code the response must have.
    headers: optional extra request headers. Left unmodified; the request always uses
      Content-Type: application/json.

  Returns:
    The response returned by `client.open`.

  Raises:
    AssertionError: if the response status code differs from `expected_code`.
  """
  params = add_csrf_param(params)
  final_url = url_for(resource, **params)

  # Copy so the caller's headers dict is not mutated across calls.
  request_headers = dict(headers or {})
  request_headers['Content-Type'] = 'application/json'

  if body is not None:
    body = json.dumps(body)

  # Required so that anonymous calls do not raise.
  g.identity = Identity(None, 'none')

  rv = client.open(final_url, method=method, data=body, headers=request_headers)
  msg = '%s %s: got %s expected: %s | %s' % (method, final_url, rv.status_code, expected_code,
                                             rv.data)
  assert rv.status_code == expected_code, msg
  return rv
 |