diff --git a/test/registry_tests.py b/test/registry_tests.py index 4300f2df9..f30bb7527 100644 --- a/test/registry_tests.py +++ b/test/registry_tests.py @@ -6,16 +6,19 @@ from flask.blueprints import Blueprint from flask.ext.testing import LiveServerTestCase from app import app +from data.database import close_db_filter, configure from endpoints.v1 import v1_bp from endpoints.api import api_bp from initdb import wipe_database, initialize_database, populate_database from endpoints.csrf import generate_csrf_token +from tempfile import NamedTemporaryFile import endpoints.decorated import json import features import tarfile +import shutil from cStringIO import StringIO from digest.checksums import compute_simple @@ -28,7 +31,9 @@ except ValueError: pass -# Add a test blueprint for generating CSRF tokens and setting feature flags. +# Add a test blueprint for generating CSRF tokens, setting feature flags and reloading the +# DB connection. + testbp = Blueprint('testbp', __name__) @testbp.route('/csrf', methods=['GET']) @@ -42,6 +47,15 @@ def set_feature(feature_name): features._FEATURES[feature_name].value = request.get_json()['value'] return jsonify({'old_value': old_value}) +@testbp.route('/reloaddb', methods=['POST']) +def reload_db(): + # Close any existing connection. + close_db_filter(None) + + # Reload the database config. + configure(app.config) + return 'OK' + app.register_blueprint(testbp, url_prefix='/__test') @@ -69,6 +83,28 @@ class TestFeature(object): headers={'Content-Type': 'application/json'}) _PORT_NUMBER = 5001 +_CLEAN_DATABASE_PATH = None + +def get_new_database_uri(): + # If a clean copy of the database has not yet been created, create one now. + global _CLEAN_DATABASE_PATH + if not _CLEAN_DATABASE_PATH: + wipe_database() + initialize_database() + populate_database() + close_db_filter(None) + + # Save the path of the clean database. + _CLEAN_DATABASE_PATH = app.config['TEST_DB_FILE'].name + + # Create a new temp file to be used as the actual backing database for the test. + # Note that we have the close() the file to ensure we can copy to it via shutil. + local_db_file = NamedTemporaryFile(delete=True) + local_db_file.close() + + # Copy the clean database to the path. + shutil.copy2(_CLEAN_DATABASE_PATH, local_db_file.name) + return 'sqlite:///{0}'.format(local_db_file.name) class RegistryTestCase(LiveServerTestCase): maxDiff = None @@ -76,20 +112,21 @@ class RegistryTestCase(LiveServerTestCase): def create_app(self): global _PORT_NUMBER _PORT_NUMBER = _PORT_NUMBER + 1 + app.config['DEBUG'] = True app.config['TESTING'] = True app.config['LIVESERVER_PORT'] = _PORT_NUMBER + app.config['DB_URI'] = get_new_database_uri() return app def setUp(self): - # Note: We cannot use the normal savepoint-based DB setup here because we are accessing - # different app instances remotely via a live webserver, which is multiprocess. Therefore, we - # completely clear the database between tests. - wipe_database() - initialize_database() - populate_database() - self.clearSession() + # Tell the remote running app to reload the database. By default, the app forks from the + # current context and has already loaded the DB config with the *original* DB URL. We call + # the remote reload method to force it to pick up the changes to DB_URI set in the create_app + # method. + self.conduct('POST', '/__test/reloaddb') + def clearSession(self): self.session = requests.Session() self.signature = None diff --git a/test/testconfig.py b/test/testconfig.py index 75c1c3f9c..2ee3e89bb 100644 --- a/test/testconfig.py +++ b/test/testconfig.py @@ -21,6 +21,7 @@ class TestConfig(DefaultConfig): TESTING = True SECRET_KEY = 'a36c9d7d-25a9-4d3f-a586-3d2f8dc40a83' + TEST_DB_FILE = TEST_DB_FILE DB_URI = os.environ.get('TEST_DATABASE_URI', 'sqlite:///{0}'.format(TEST_DB_FILE.name)) DB_CONNECTION_ARGS = { 'threadlocals': True,