diff --git a/application.py b/application.py index 26101f50b..63c1e2d8a 100644 --- a/application.py +++ b/application.py @@ -2,6 +2,7 @@ import logging from app import app as application +import model logging.basicConfig(**application.config['LOGGING_CONFIG']) @@ -27,6 +28,16 @@ application.register_blueprint(registry, url_prefix='/v1') application.register_blueprint(api, url_prefix='/api') application.register_blueprint(webhooks, url_prefix='/webhooks') + +def close_db(exc): + db = model.db + if not db.is_closed(): + logger.debug('Disconnecting from database.') + db.close() + +application.teardown_request(close_db) + + # Remove this for prod config application.debug = True diff --git a/data/database.py b/data/database.py index 2afd3f79c..d2319773b 100644 --- a/data/database.py +++ b/data/database.py @@ -12,16 +12,6 @@ logger = logging.getLogger(__name__) db = app.config['DB_DRIVER'](app.config['DB_NAME'], **app.config['DB_CONNECTION_ARGS']) - -def close_db(exc): - if not db.is_closed(): - logger.debug('Disconnecting from database.') - db.close() - - -app.teardown_request(close_db) - - def random_string_generator(length=16): def random_string(): random = SystemRandom() diff --git a/initdb.py b/initdb.py index a3ee26362..55eb76c71 100644 --- a/initdb.py +++ b/initdb.py @@ -29,7 +29,6 @@ SAMPLE_CMDS = [["/bin/bash"], REFERENCE_DATE = datetime(2013, 6, 23) TEST_STRIPE_ID = 'cus_2tmnh3PkXQS8NG' - def __gen_checksum(image_id): h = hashlib.md5(image_id) return 'tarsum+sha256:' + h.hexdigest() + h.hexdigest() @@ -112,6 +111,45 @@ def __generate_repository(user, name, description, is_public, permissions, return repo +from util.dbutil import savepoint_sqlite + +db_initialized_for_testing = False +testcases = {} + +def finished_database_for_testing(testcase): + """ Called when a testcase has finished using the database, indicating that + any changes should be discarded. + """ + global testcases + testcases[testcase]['savepoint'].__exit__(True, None, None) + +def setup_database_for_testing(testcase): + """ Called when a testcase has started using the database, indicating that + the database should be setup (if not already) and a savepoint created. + """ + + # Sanity check to make sure we're not killing our prod db + db = model.db + if (not isinstance(model.db, SqliteDatabase) or + app.config['DB_DRIVER'] is not SqliteDatabase): + raise RuntimeError('Attempted to wipe production database!') + + global db_initialized_for_testing + if not db_initialized_for_testing: + logger.debug('Setting up DB for testing.') + + # Setup the database. + wipe_database() + initialize_database() + populate_database() + + db_initialized_for_testing = True + + # Create a savepoint for the testcase. + global testcases + testcases[testcase] = {} + testcases[testcase]['savepoint'] = savepoint_sqlite(db) + testcases[testcase]['savepoint'].__enter__() def initialize_database(): create_model_tables(all_models) diff --git a/test/test_api_security.py b/test/test_api_security.py index 0c68fcd6f..78cb7d3a7 100644 --- a/test/test_api_security.py +++ b/test/test_api_security.py @@ -2,7 +2,7 @@ import unittest from endpoints.api import api from app import app -from initdb import wipe_database, initialize_database, populate_database +from initdb import setup_database_for_testing, finished_database_for_testing from specs import build_specs @@ -16,9 +16,10 @@ ADMIN_ACCESS_USER = 'devtable' class ApiTestCase(unittest.TestCase): def setUp(self): - wipe_database() - initialize_database() - populate_database() + setup_database_for_testing(self) + + def tearDown(self): + finished_database_for_testing(self) class _SpecTestBuilder(type): @@ -28,8 +29,10 @@ class _SpecTestBuilder(type): with app.test_client() as c: if auth_username: # Temporarily remove the teardown functions - teardown_funcs = app.teardown_request_funcs[None] - app.teardown_request_funcs[None] = [] + teardown_funcs = [] + if None in app.teardown_request_funcs: + teardown_funcs = app.teardown_request_funcs[None] + app.teardown_request_funcs[None] = [] with c.session_transaction() as sess: sess['user_id'] = auth_username diff --git a/test/test_endpoint_security.py b/test/test_endpoint_security.py index d810b96eb..724629f3a 100644 --- a/test/test_endpoint_security.py +++ b/test/test_endpoint_security.py @@ -2,7 +2,7 @@ import unittest from app import app from util.names import parse_namespace_repository -from initdb import wipe_database, initialize_database, populate_database +from initdb import setup_database_for_testing, finished_database_for_testing from specs import build_index_specs from endpoints.registry import registry from endpoints.index import index @@ -20,10 +20,11 @@ ADMIN_ACCESS_USER = 'devtable' class EndpointTestCase(unittest.TestCase): - def setUp(self): - wipe_database() - initialize_database() - populate_database() + def setUp(self): + setup_database_for_testing(self) + + def tearDown(self): + finished_database_for_testing(self) class _SpecTestBuilder(type): @@ -33,8 +34,10 @@ class _SpecTestBuilder(type): with app.test_client() as c: if session_var_list: # Temporarily remove the teardown functions - teardown_funcs = app.teardown_request_funcs[None] - app.teardown_request_funcs[None] = [] + teardown_funcs = [] + if None in app.teardown_request_funcs: + teardown_funcs = app.teardown_request_funcs[None] + app.teardown_request_funcs[None] = [] with c.session_transaction() as sess: for sess_key, sess_val in session_var_list: diff --git a/util/dbutil.py b/util/dbutil.py new file mode 100644 index 000000000..4e565abca --- /dev/null +++ b/util/dbutil.py @@ -0,0 +1,60 @@ +import uuid + +# Note: These savepoint classes are implemented in peewee, but not the version we have. Copied here from https://github.com/coleifer/peewee/blob/b657d08a14e4cdafee417111ccba62ede9344222/peewee.py + +class savepoint(object): + def __init__(self, db, sid=None): + self.db = db + _compiler = db.compiler() + self.sid = sid or 's' + uuid.uuid4().get_hex() + self.quoted_sid = _compiler.quote(self.sid) + + def _execute(self, query): + self.db.execute_sql(query, require_commit=False) + + def commit(self): + self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid) + + def rollback(self): + self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + + def __enter__(self): + self._orig_autocommit = self.db.get_autocommit() + self.db.set_autocommit(False) + self._execute('SAVEPOINT %s;' % self.quoted_sid) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + self.rollback() + else: + try: + self.commit() + except: + self.rollback() + raise + finally: + self.db.set_autocommit(self._orig_autocommit) + + +class savepoint_sqlite(savepoint): + def __enter__(self): + conn = self.db.get_conn() + # For sqlite, the connection's isolation_level *must* be set to None. + # The act of setting it, though, will break any existing savepoints, + # so only write to it if necessary. + if conn.isolation_level is not None: + self._orig_isolation_level = conn.isolation_level + conn.isolation_level = None + else: + self._orig_isolation_level = None + super(savepoint_sqlite, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + return super(savepoint_sqlite, self).__exit__( + exc_type, exc_val, exc_tb) + finally: + if self._orig_isolation_level is not None: + self.db.get_conn().isolation_level = self._orig_isolation_level