import base64
import unittest

from contextlib import contextmanager
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile

import jwt
import requests

from Crypto.PublicKey import RSA
from flask import Flask, jsonify, request, make_response

from app import app
from data.users import ExternalJWTAuthN
from initdb import setup_database_for_testing, finished_database_for_testing
from test.helpers import liveserver_app


# Port of the most recently created fake JWT server. `_create_app` bumps it so
# each fake server gets its own port.
_PORT_NUMBER = 5001


@contextmanager
def fake_jwt(requires_email=True):
    """ Context manager which instantiates and runs a webserver with a fake JWT implementation,
        until the result is yielded.

        `requires_email` is passed both to the fake server (whether query results include
        emails) and to the ExternalJWTAuthN instance that is yielded. The temporary public
        key file backing the auth instance is closed (and thereby deleted) on exit.

        Usage:
        with fake_jwt() as jwt_auth:
          # Make jwt_auth requests.
    """
    jwt_app, port, public_key = _create_app(requires_email)
    try:
        server_url = 'http://' + jwt_app.config['SERVER_HOSTNAME']

        verify_url = server_url + '/user/verify'
        query_url = server_url + '/user/query'
        getuser_url = server_url + '/user/get'

        jwt_auth = ExternalJWTAuthN(verify_url, query_url, getuser_url, 'authy', '',
                                    app.config['HTTPCLIENT'], 300,
                                    public_key_path=public_key.name,
                                    requires_email=requires_email)

        with liveserver_app(jwt_app, port):
            yield jwt_auth
    finally:
        # The key file was created with delete=True, so closing it also removes it from
        # disk instead of leaving it around until garbage collection.
        public_key.close()


def _generate_certs():
    """ Generates a fresh RSA keypair for signing fake JWTs.

        Returns a tuple of (public_key, private_key_data): `public_key` is an open
        NamedTemporaryFile holding the OpenSSH-encoded public key (closing it deletes it),
        and `private_key_data` is the PEM-encoded private key.
    """
    public_key = NamedTemporaryFile(delete=True)

    # NOTE(review): the small key size is presumably chosen to keep test setup fast;
    # these keys only sign tokens for the fake server below.
    key = RSA.generate(1024)
    private_key_data = key.exportKey('PEM')

    public_key.write(key.publickey().exportKey('OpenSSH'))
    # Push the key bytes to disk before anything opens the file by name.
    public_key.flush()
    public_key.seek(0)
    return (public_key, private_key_data)


def _create_app(emails=True):
    """ Creates a Flask app implementing a fake external JWT authentication service.

        The app serves `/user/query`, `/user/get` and `/user/verify` over a fixed set of
        users, answering successful requests with a JWT signed by a freshly generated RSA
        key. The username 'disabled' always gets a 401. When `emails` is False, query
        results omit email addresses.

        Returns a tuple of (jwt_app, port, public_key), where `public_key` is the open
        temporary file holding the public half of the signing key.
    """
    global _PORT_NUMBER
    _PORT_NUMBER = _PORT_NUMBER + 1

    public_key, private_key_data = _generate_certs()

    users = [
        {'name': 'cool.user', 'email': 'user@domain.com', 'password': 'password'},
        {'name': 'some.neat.user', 'email': 'neat@domain.com', 'password': 'foobar'},
    ]

    jwt_app = Flask('testjwt')
    jwt_app.config['SERVER_HOSTNAME'] = 'localhost:%s' % _PORT_NUMBER

    def _get_basic_auth():
        """ Returns the (username, password) pair from the request's Basic auth header. """
        encoded = request.headers['Authorization'][len('Basic '):]
        # b64decode returns bytes on Python 3; decode so the ':' split works on 2 and 3.
        data = base64.b64decode(encoded).decode('utf-8')
        # partition (unlike split + unpack) tolerates a missing ':' (empty password).
        username, _, password = data.partition(':')
        return username, password

    def _find_user(username):
        """ Returns the first user whose name or email equals `username`, or None. """
        for user in users:
            if user['name'] == username or user['email'] == username:
                return user
        return None

    def _signed_token_response(audience, **claims):
        """ Returns a JSON response `{'token': ...}` holding an RS256 JWT issued by 'authy'
            for `audience`, valid for 60 seconds, with `claims` merged into the payload. """
        now = datetime.utcnow()
        token_data = {
            'iss': 'authy',
            'aud': audience,
            'nbf': now,
            'iat': now,
            'exp': now + timedelta(seconds=60),
        }
        token_data.update(claims)
        encoded = jwt.encode(token_data, private_key_data, 'RS256')
        return jsonify({'token': encoded})

    @jwt_app.route('/user/query', methods=['GET'])
    def query_users():
        query = request.args.get('query')
        results = []
        for user in users:
            if user['name'].startswith(query):
                result = {'username': user['name']}
                if emails:
                    result['email'] = user['email']
                results.append(result)

        return _signed_token_response('quay.io/jwtauthn/query', results=results)

    @jwt_app.route('/user/get', methods=['GET'])
    def get_user():
        username = request.args.get('username')

        if username == 'disabled':
            return make_response('User is currently disabled', 401)

        user = _find_user(username)
        if user is None:
            return make_response('Invalid username or password', 404)

        return _signed_token_response('quay.io/jwtauthn/getuser',
                                      sub=user['name'], email=user['email'])

    @jwt_app.route('/user/verify', methods=['GET'])
    def verify_user():
        username, password = _get_basic_auth()

        if username == 'disabled':
            return make_response('User is currently disabled', 401)

        user = _find_user(username)
        if user is None:
            return make_response('Invalid username or password', 404)

        if password != user['password']:
            return make_response('', 404)

        return _signed_token_response('quay.io/jwtauthn',
                                      sub=user['name'], email=user['email'])

    jwt_app.config['TESTING'] = True
    return jwt_app, _PORT_NUMBER, public_key


class JWTAuthTestMixin(object):
    """ Mixin defining all the JWT auth tests. Subclasses must define the `emails`
        property, which controls whether the fake server returns emails. """
    maxDiff = None

    @property
    def emails(self):
        """ Whether the fake JWT server includes emails; overridden by subclasses. """
        raise NotImplementedError

    def setUp(self):
        setup_database_for_testing(self)
        self.app = app.test_client()
        self.ctx = app.test_request_context()
        self.ctx.__enter__()

        self.session = requests.Session()

    def tearDown(self):
        self.session.close()
        finished_database_for_testing(self)
        # No exception is being propagated out of the test, so pass an empty triple.
        self.ctx.__exit__(None, None, None)

    def test_verify_and_link_user(self):
        """ Verifying credentials links the external user to a sanitized internal username. """
        with fake_jwt(self.emails) as jwt_auth:
            result, error_message = jwt_auth.verify_and_link_user('invaliduser', 'foobar')
            self.assertEqual('Invalid username or password', error_message)
            self.assertIsNone(result)

            result, _ = jwt_auth.verify_and_link_user('cool.user', 'invalidpassword')
            self.assertIsNone(result)

            result, _ = jwt_auth.verify_and_link_user('cool.user', 'password')
            self.assertIsNotNone(result)
            self.assertEqual('cool_user', result.username)

            result, _ = jwt_auth.verify_and_link_user('some.neat.user', 'foobar')
            self.assertIsNotNone(result)
            self.assertEqual('some_neat_user', result.username)

    def test_confirm_existing_user(self):
        """ Confirmation works by internal username only, never by the external one. """
        with fake_jwt(self.emails) as jwt_auth:
            # Create the users in the DB.
            result, _ = jwt_auth.verify_and_link_user('cool.user', 'password')
            self.assertIsNotNone(result)

            result, _ = jwt_auth.verify_and_link_user('some.neat.user', 'foobar')
            self.assertIsNotNone(result)

            # Confirm a user with the same internal and external username.
            result, _ = jwt_auth.confirm_existing_user('cool_user', 'invalidpassword')
            self.assertIsNone(result)

            result, _ = jwt_auth.confirm_existing_user('cool_user', 'password')
            self.assertIsNotNone(result)
            self.assertEqual('cool_user', result.username)

            # Fail to confirm the *external* username, which should return nothing.
            result, _ = jwt_auth.confirm_existing_user('some.neat.user', 'password')
            self.assertIsNone(result)

            # Now confirm the internal username.
            result, _ = jwt_auth.confirm_existing_user('some_neat_user', 'foobar')
            self.assertIsNotNone(result)
            self.assertEqual('some_neat_user', result.username)

    def test_disabled_user_custom_error(self):
        """ The server's custom error text is surfaced for a disabled user. """
        with fake_jwt(self.emails) as jwt_auth:
            result, error_message = jwt_auth.verify_and_link_user('disabled', 'password')
            self.assertIsNone(result)
            self.assertEqual('User is currently disabled', error_message)

    def test_query(self):
        """ Prefix queries return matching users, with emails only when enabled. """
        with fake_jwt(self.emails) as jwt_auth:
            # Lookup `cool`.
            results, identifier, error_message = jwt_auth.query_users('cool')
            self.assertIsNone(error_message)
            self.assertEqual('jwtauthn', identifier)
            self.assertEqual(1, len(results))

            self.assertEqual('cool.user', results[0].username)
            self.assertEqual('user@domain.com' if self.emails else None, results[0].email)

            # Lookup `some`.
            results, identifier, error_message = jwt_auth.query_users('some')
            self.assertIsNone(error_message)
            self.assertEqual('jwtauthn', identifier)
            self.assertEqual(1, len(results))

            self.assertEqual('some.neat.user', results[0].username)
            self.assertEqual('neat@domain.com' if self.emails else None, results[0].email)

            # Lookup `unknown`.
            results, identifier, error_message = jwt_auth.query_users('unknown')
            self.assertIsNone(error_message)
            self.assertEqual('jwtauthn', identifier)
            self.assertEqual(0, len(results))

    def test_get_user(self):
        """ Looking up a user returns its external username and email. """
        with fake_jwt(self.emails) as jwt_auth:
            # Lookup cool.user.
            result, error_message = jwt_auth.get_user('cool.user')
            self.assertIsNone(error_message)
            self.assertIsNotNone(result)

            self.assertEqual('cool.user', result.username)
            self.assertEqual('user@domain.com', result.email)

            # Lookup some.neat.user.
            result, error_message = jwt_auth.get_user('some.neat.user')
            self.assertIsNone(error_message)
            self.assertIsNotNone(result)

            self.assertEqual('some.neat.user', result.username)
            self.assertEqual('neat@domain.com', result.email)

            # Lookup unknown user.
            result, _ = jwt_auth.get_user('unknownuser')
            self.assertIsNone(result)

    def test_link_user(self):
        """ Linking is idempotent and the linked user can then be confirmed. """
        with fake_jwt(self.emails) as jwt_auth:
            # Link cool.user.
            user, error_message = jwt_auth.link_user('cool.user')
            self.assertIsNone(error_message)
            self.assertIsNotNone(user)
            self.assertEqual('cool_user', user.username)

            # Link again. Should return the same user record.
            user_again, _ = jwt_auth.link_user('cool.user')
            self.assertEqual(user_again.id, user.id)

            # Confirm cool.user.
            result, _ = jwt_auth.confirm_existing_user('cool_user', 'password')
            self.assertIsNotNone(result)
            self.assertEqual('cool_user', result.username)

    def test_link_invalid_user(self):
        """ Linking an unknown external user reports an error and returns no user. """
        with fake_jwt(self.emails) as jwt_auth:
            user, error_message = jwt_auth.link_user('invaliduser')
            self.assertIsNotNone(error_message)
            self.assertIsNone(user)


class JWTAuthNoEmailTestCase(JWTAuthTestMixin, unittest.TestCase):
    """ Test cases for JWT auth, with emails disabled. """
    @property
    def emails(self):
        return False


class JWTAuthTestCase(JWTAuthTestMixin, unittest.TestCase):
    """ Test cases for JWT auth, with emails enabled. """
    @property
    def emails(self):
        return True


if __name__ == '__main__':
    unittest.main()