From 0833c880655334b9d056a9545f3979b16516f363 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Thu, 30 Jan 2014 20:57:40 -0500 Subject: [PATCH 1/3] Make testing much faster by using a save point, rather than recreating the database every test --- application.py | 11 +++++++ data/database.py | 10 ------ initdb.py | 40 ++++++++++++++++++++++- test/test_api_security.py | 15 +++++---- test/test_endpoint_security.py | 17 ++++++---- util/dbutil.py | 60 ++++++++++++++++++++++++++++++++++ 6 files changed, 129 insertions(+), 24 deletions(-) create mode 100644 util/dbutil.py diff --git a/application.py b/application.py index 26101f50b..63c1e2d8a 100644 --- a/application.py +++ b/application.py @@ -2,6 +2,7 @@ import logging from app import app as application +import model logging.basicConfig(**application.config['LOGGING_CONFIG']) @@ -27,6 +28,16 @@ application.register_blueprint(registry, url_prefix='/v1') application.register_blueprint(api, url_prefix='/api') application.register_blueprint(webhooks, url_prefix='/webhooks') + +def close_db(exc): + db = model.db + if not db.is_closed(): + logger.debug('Disconnecting from database.') + db.close() + +application.teardown_request(close_db) + + # Remove this for prod config application.debug = True diff --git a/data/database.py b/data/database.py index 2afd3f79c..d2319773b 100644 --- a/data/database.py +++ b/data/database.py @@ -12,16 +12,6 @@ logger = logging.getLogger(__name__) db = app.config['DB_DRIVER'](app.config['DB_NAME'], **app.config['DB_CONNECTION_ARGS']) - -def close_db(exc): - if not db.is_closed(): - logger.debug('Disconnecting from database.') - db.close() - - -app.teardown_request(close_db) - - def random_string_generator(length=16): def random_string(): random = SystemRandom() diff --git a/initdb.py b/initdb.py index a3ee26362..55eb76c71 100644 --- a/initdb.py +++ b/initdb.py @@ -29,7 +29,6 @@ SAMPLE_CMDS = [["/bin/bash"], REFERENCE_DATE = datetime(2013, 6, 23) TEST_STRIPE_ID = 'cus_2tmnh3PkXQS8NG' - def __gen_checksum(image_id): h = hashlib.md5(image_id) return 'tarsum+sha256:' + h.hexdigest() + h.hexdigest() @@ -112,6 +111,45 @@ def __generate_repository(user, name, description, is_public, permissions, return repo +from util.dbutil import savepoint_sqlite + +db_initialized_for_testing = False +testcases = {} + +def finished_database_for_testing(testcase): + """ Called when a testcase has finished using the database, indicating that + any changes should be discarded. + """ + global testcases + testcases[testcase]['savepoint'].__exit__(True, None, None) + +def setup_database_for_testing(testcase): + """ Called when a testcase has started using the database, indicating that + the database should be setup (if not already) and a savepoint created. + """ + + # Sanity check to make sure we're not killing our prod db + db = model.db + if (not isinstance(model.db, SqliteDatabase) or + app.config['DB_DRIVER'] is not SqliteDatabase): + raise RuntimeError('Attempted to wipe production database!') + + global db_initialized_for_testing + if not db_initialized_for_testing: + logger.debug('Setting up DB for testing.') + + # Setup the database. + wipe_database() + initialize_database() + populate_database() + + db_initialized_for_testing = True + + # Create a savepoint for the testcase. + global testcases + testcases[testcase] = {} + testcases[testcase]['savepoint'] = savepoint_sqlite(db) + testcases[testcase]['savepoint'].__enter__() def initialize_database(): create_model_tables(all_models) diff --git a/test/test_api_security.py b/test/test_api_security.py index 0c68fcd6f..78cb7d3a7 100644 --- a/test/test_api_security.py +++ b/test/test_api_security.py @@ -2,7 +2,7 @@ import unittest from endpoints.api import api from app import app -from initdb import wipe_database, initialize_database, populate_database +from initdb import setup_database_for_testing, finished_database_for_testing from specs import build_specs @@ -16,9 +16,10 @@ ADMIN_ACCESS_USER = 'devtable' class ApiTestCase(unittest.TestCase): def setUp(self): - wipe_database() - initialize_database() - populate_database() + setup_database_for_testing(self) + + def tearDown(self): + finished_database_for_testing(self) class _SpecTestBuilder(type): @@ -28,8 +29,10 @@ class _SpecTestBuilder(type): with app.test_client() as c: if auth_username: # Temporarily remove the teardown functions - teardown_funcs = app.teardown_request_funcs[None] - app.teardown_request_funcs[None] = [] + teardown_funcs = [] + if None in app.teardown_request_funcs: + teardown_funcs = app.teardown_request_funcs[None] + app.teardown_request_funcs[None] = [] with c.session_transaction() as sess: sess['user_id'] = auth_username diff --git a/test/test_endpoint_security.py b/test/test_endpoint_security.py index d810b96eb..724629f3a 100644 --- a/test/test_endpoint_security.py +++ b/test/test_endpoint_security.py @@ -2,7 +2,7 @@ import unittest from app import app from util.names import parse_namespace_repository -from initdb import wipe_database, initialize_database, populate_database +from initdb import setup_database_for_testing, finished_database_for_testing from specs import build_index_specs from endpoints.registry import registry from endpoints.index import index @@ -20,10 +20,11 @@ ADMIN_ACCESS_USER = 'devtable' class EndpointTestCase(unittest.TestCase): - def setUp(self): - wipe_database() - initialize_database() - populate_database() + def setUp(self): + setup_database_for_testing(self) + + def tearDown(self): + finished_database_for_testing(self) class _SpecTestBuilder(type): @@ -33,8 +34,10 @@ class _SpecTestBuilder(type): with app.test_client() as c: if session_var_list: # Temporarily remove the teardown functions - teardown_funcs = app.teardown_request_funcs[None] - app.teardown_request_funcs[None] = [] + teardown_funcs = [] + if None in app.teardown_request_funcs: + teardown_funcs = app.teardown_request_funcs[None] + app.teardown_request_funcs[None] = [] with c.session_transaction() as sess: for sess_key, sess_val in session_var_list: diff --git a/util/dbutil.py b/util/dbutil.py new file mode 100644 index 000000000..4e565abca --- /dev/null +++ b/util/dbutil.py @@ -0,0 +1,60 @@ +import uuid + +# Note: These savepoint classes are implemented in peewee, but not the version we have. Copied here from https://github.com/coleifer/peewee/blob/b657d08a14e4cdafee417111ccba62ede9344222/peewee.py + +class savepoint(object): + def __init__(self, db, sid=None): + self.db = db + _compiler = db.compiler() + self.sid = sid or 's' + uuid.uuid4().get_hex() + self.quoted_sid = _compiler.quote(self.sid) + + def _execute(self, query): + self.db.execute_sql(query, require_commit=False) + + def commit(self): + self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid) + + def rollback(self): + self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + + def __enter__(self): + self._orig_autocommit = self.db.get_autocommit() + self.db.set_autocommit(False) + self._execute('SAVEPOINT %s;' % self.quoted_sid) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + self.rollback() + else: + try: + self.commit() + except: + self.rollback() + raise + finally: + self.db.set_autocommit(self._orig_autocommit) + + +class savepoint_sqlite(savepoint): + def __enter__(self): + conn = self.db.get_conn() + # For sqlite, the connection's isolation_level *must* be set to None. + # The act of setting it, though, will break any existing savepoints, + # so only write to it if necessary. + if conn.isolation_level is not None: + self._orig_isolation_level = conn.isolation_level + conn.isolation_level = None + else: + self._orig_isolation_level = None + super(savepoint_sqlite, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + return super(savepoint_sqlite, self).__exit__( + exc_type, exc_val, exc_tb) + finally: + if self._orig_isolation_level is not None: + self.db.get_conn().isolation_level = self._orig_isolation_level From 62deddce24faa2215cf18088f5986aceb7d4a6f9 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Thu, 30 Jan 2014 21:35:39 -0500 Subject: [PATCH 2/3] Upgrade peewee --- initdb.py | 4 ++-- requirements.txt | 2 +- util/dbutil.py | 60 ------------------------------------------------ 3 files changed, 3 insertions(+), 63 deletions(-) delete mode 100644 util/dbutil.py diff --git a/initdb.py b/initdb.py index 55eb76c71..d00c0eeb9 100644 --- a/initdb.py +++ b/initdb.py @@ -4,7 +4,8 @@ import hashlib import random from datetime import datetime, timedelta -from peewee import SqliteDatabase, create_model_tables, drop_model_tables +from peewee import (SqliteDatabase, create_model_tables, drop_model_tables, + savepoint_sqlite) from data.database import * from data import model @@ -111,7 +112,6 @@ def __generate_repository(user, name, description, is_public, permissions, return repo -from util.dbutil import savepoint_sqlite db_initialized_for_testing = False testcases = {} diff --git a/requirements.txt b/requirements.txt index ce9a22c54..98697d7b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ lockfile==0.9.1 marisa-trie==0.5.1 mixpanel-py==3.0.0 paramiko==1.12.0 -peewee==2.1.7 +peewee==2.2.0 py-bcrypt==0.4 pyPdf==1.13 pycrypto==2.6.1 diff --git a/util/dbutil.py b/util/dbutil.py deleted file mode 100644 index 4e565abca..000000000 --- a/util/dbutil.py +++ /dev/null @@ -1,60 +0,0 @@ -import uuid - -# Note: These savepoint classes are implemented in peewee, but not the version we have. Copied here from https://github.com/coleifer/peewee/blob/b657d08a14e4cdafee417111ccba62ede9344222/peewee.py - -class savepoint(object): - def __init__(self, db, sid=None): - self.db = db - _compiler = db.compiler() - self.sid = sid or 's' + uuid.uuid4().get_hex() - self.quoted_sid = _compiler.quote(self.sid) - - def _execute(self, query): - self.db.execute_sql(query, require_commit=False) - - def commit(self): - self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid) - - def rollback(self): - self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) - - def __enter__(self): - self._orig_autocommit = self.db.get_autocommit() - self.db.set_autocommit(False) - self._execute('SAVEPOINT %s;' % self.quoted_sid) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - if exc_type: - self.rollback() - else: - try: - self.commit() - except: - self.rollback() - raise - finally: - self.db.set_autocommit(self._orig_autocommit) - - -class savepoint_sqlite(savepoint): - def __enter__(self): - conn = self.db.get_conn() - # For sqlite, the connection's isolation_level *must* be set to None. - # The act of setting it, though, will break any existing savepoints, - # so only write to it if necessary. - if conn.isolation_level is not None: - self._orig_isolation_level = conn.isolation_level - conn.isolation_level = None - else: - self._orig_isolation_level = None - super(savepoint_sqlite, self).__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - return super(savepoint_sqlite, self).__exit__( - exc_type, exc_val, exc_tb) - finally: - if self._orig_isolation_level is not None: - self.db.get_conn().isolation_level = self._orig_isolation_level From acbb075d13d55bd9e1c3c0f345a0e76617458307 Mon Sep 17 00:00:00 2001 From: yackob03 Date: Fri, 31 Jan 2014 11:14:07 -0500 Subject: [PATCH 3/3] Fix the imports for the model db. --- application.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/application.py b/application.py index 63c1e2d8a..0b119f5d2 100644 --- a/application.py +++ b/application.py @@ -1,8 +1,8 @@ import logging from app import app as application +from data.model import db as model_db -import model logging.basicConfig(**application.config['LOGGING_CONFIG']) @@ -30,7 +30,7 @@ application.register_blueprint(webhooks, url_prefix='/webhooks') def close_db(exc): - db = model.db + db = model_db if not db.is_closed(): logger.debug('Disconnecting from database.') db.close()