diff --git a/application.py b/application.py index 74847a0c1..7e6139ada 100644 --- a/application.py +++ b/application.py @@ -1,6 +1,7 @@ import logging from app import app as application +from data.model import db as model_db logging.basicConfig(**application.config['LOGGING_CONFIG']) @@ -30,6 +31,16 @@ application.register_blueprint(registry, url_prefix='/v1') application.register_blueprint(api, url_prefix='/api') application.register_blueprint(webhooks, url_prefix='/webhooks') + +def close_db(exc): + db = model_db + if not db.is_closed(): + logger.debug('Disconnecting from database.') + db.close() + +application.teardown_request(close_db) + + # Remove this for prod config application.debug = True diff --git a/data/database.py b/data/database.py index 2afd3f79c..d2319773b 100644 --- a/data/database.py +++ b/data/database.py @@ -12,16 +12,6 @@ logger = logging.getLogger(__name__) db = app.config['DB_DRIVER'](app.config['DB_NAME'], **app.config['DB_CONNECTION_ARGS']) - -def close_db(exc): - if not db.is_closed(): - logger.debug('Disconnecting from database.') - db.close() - - -app.teardown_request(close_db) - - def random_string_generator(length=16): def random_string(): random = SystemRandom() diff --git a/initdb.py b/initdb.py index a3ee26362..d00c0eeb9 100644 --- a/initdb.py +++ b/initdb.py @@ -4,7 +4,8 @@ import hashlib import random from datetime import datetime, timedelta -from peewee import SqliteDatabase, create_model_tables, drop_model_tables +from peewee import (SqliteDatabase, create_model_tables, drop_model_tables, + savepoint_sqlite) from data.database import * from data import model @@ -29,7 +30,6 @@ SAMPLE_CMDS = [["/bin/bash"], REFERENCE_DATE = datetime(2013, 6, 23) TEST_STRIPE_ID = 'cus_2tmnh3PkXQS8NG' - def __gen_checksum(image_id): h = hashlib.md5(image_id) return 'tarsum+sha256:' + h.hexdigest() + h.hexdigest() @@ -113,6 +113,44 @@ def __generate_repository(user, name, description, is_public, permissions, return repo +db_initialized_for_testing = False +testcases = {} + +def finished_database_for_testing(testcase): + """ Called when a testcase has finished using the database, indicating that + any changes should be discarded. + """ + global testcases + testcases[testcase]['savepoint'].__exit__(True, None, None) + +def setup_database_for_testing(testcase): + """ Called when a testcase has started using the database, indicating that + the database should be setup (if not already) and a savepoint created. + """ + + # Sanity check to make sure we're not killing our prod db + db = model.db + if (not isinstance(model.db, SqliteDatabase) or + app.config['DB_DRIVER'] is not SqliteDatabase): + raise RuntimeError('Attempted to wipe production database!') + + global db_initialized_for_testing + if not db_initialized_for_testing: + logger.debug('Setting up DB for testing.') + + # Setup the database. + wipe_database() + initialize_database() + populate_database() + + db_initialized_for_testing = True + + # Create a savepoint for the testcase. + global testcases + testcases[testcase] = {} + testcases[testcase]['savepoint'] = savepoint_sqlite(db) + testcases[testcase]['savepoint'].__enter__() + def initialize_database(): create_model_tables(all_models) diff --git a/requirements.txt b/requirements.txt index ce9a22c54..98697d7b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ lockfile==0.9.1 marisa-trie==0.5.1 mixpanel-py==3.0.0 paramiko==1.12.0 -peewee==2.1.7 +peewee==2.2.0 py-bcrypt==0.4 pyPdf==1.13 pycrypto==2.6.1 diff --git a/test/test_api_security.py b/test/test_api_security.py index 0c68fcd6f..78cb7d3a7 100644 --- a/test/test_api_security.py +++ b/test/test_api_security.py @@ -2,7 +2,7 @@ import unittest from endpoints.api import api from app import app -from initdb import wipe_database, initialize_database, populate_database +from initdb import setup_database_for_testing, finished_database_for_testing from specs import build_specs @@ -16,9 +16,10 @@ ADMIN_ACCESS_USER = 'devtable' class ApiTestCase(unittest.TestCase): def setUp(self): - wipe_database() - initialize_database() - populate_database() + setup_database_for_testing(self) + + def tearDown(self): + finished_database_for_testing(self) class _SpecTestBuilder(type): @@ -28,8 +29,10 @@ class _SpecTestBuilder(type): with app.test_client() as c: if auth_username: # Temporarily remove the teardown functions - teardown_funcs = app.teardown_request_funcs[None] - app.teardown_request_funcs[None] = [] + teardown_funcs = [] + if None in app.teardown_request_funcs: + teardown_funcs = app.teardown_request_funcs[None] + app.teardown_request_funcs[None] = [] with c.session_transaction() as sess: sess['user_id'] = auth_username diff --git a/test/test_endpoint_security.py b/test/test_endpoint_security.py index d810b96eb..724629f3a 100644 --- a/test/test_endpoint_security.py +++ b/test/test_endpoint_security.py @@ -2,7 +2,7 @@ import unittest from app import app from util.names import parse_namespace_repository -from initdb import wipe_database, initialize_database, populate_database +from initdb import setup_database_for_testing, finished_database_for_testing from specs import build_index_specs from endpoints.registry import registry from endpoints.index import index @@ -20,10 +20,11 @@ ADMIN_ACCESS_USER = 'devtable' class EndpointTestCase(unittest.TestCase): - def setUp(self): - wipe_database() - initialize_database() - populate_database() + def setUp(self): + setup_database_for_testing(self) + + def tearDown(self): + finished_database_for_testing(self) class _SpecTestBuilder(type): @@ -33,8 +34,10 @@ class _SpecTestBuilder(type): with app.test_client() as c: if session_var_list: # Temporarily remove the teardown functions - teardown_funcs = app.teardown_request_funcs[None] - app.teardown_request_funcs[None] = [] + teardown_funcs = [] + if None in app.teardown_request_funcs: + teardown_funcs = app.teardown_request_funcs[None] + app.teardown_request_funcs[None] = [] with c.session_transaction() as sess: for sess_key, sess_val in session_var_list: