From e7064f1191dd433644534653f1afb66a2fa11339 Mon Sep 17 00:00:00 2001 From: jakedt Date: Sun, 16 Feb 2014 18:59:24 -0500 Subject: [PATCH] Fix the tests and the one bug that it highlighted. --- config.py | 26 ++++++++++++++++++++++++-- data/model.py | 13 ++++++++----- data/queue.py | 6 +++++- endpoints/api.py | 16 ++++++---------- endpoints/common.py | 4 ++++ test/test_api_usage.py | 12 ++++++------ 6 files changed, 53 insertions(+), 24 deletions(-) diff --git a/config.py b/config.py index c81667d4d..073dde963 100644 --- a/config.py +++ b/config.py @@ -33,7 +33,15 @@ class MailConfig(object): TESTING = False -class SQLiteDB(object): +class RealTransactions(object): + @staticmethod + def create_transaction(db): + return db.transaction() + + DB_TRANSACTION_FACTORY = create_transaction + + +class SQLiteDB(RealTransactions): DB_NAME = 'test/data/test.db' DB_CONNECTION_ARGS = { 'threadlocals': True @@ -41,13 +49,27 @@ class SQLiteDB(object): DB_DRIVER = SqliteDatabase +class FakeTransaction(object): + def __enter__(self): + return self + + def __exit__(self, exc_type, value, traceback): + pass + + class EphemeralDB(object): DB_NAME = ':memory:' DB_CONNECTION_ARGS = {} DB_DRIVER = SqliteDatabase + @staticmethod + def create_transaction(db): + return FakeTransaction() -class RDSMySQL(object): + DB_TRANSACTION_FACTORY = create_transaction + + +class RDSMySQL(RealTransactions): DB_NAME = 'quay' DB_CONNECTION_ARGS = { 'host': 'fluxmonkeylogin.cb0vumcygprn.us-east-1.rds.amazonaws.com', diff --git a/data/model.py b/data/model.py index 3d9c1e128..812a367eb 100644 --- a/data/model.py +++ b/data/model.py @@ -13,7 +13,7 @@ from util.names import format_robot_username logger = logging.getLogger(__name__) store = app.config['STORAGE'] - +transaction_factory = app.config['DB_TRANSACTION_FACTORY'] class DataModelException(Exception): pass @@ -580,6 +580,9 @@ def _visible_repository_query(username=None, include_public=True, limit=None, def _filter_to_repos_for_user(query, username=None, namespace=None, include_public=True): + if not include_public and not username: + return Repository.select().where(Repository.id == '-1') + where_clause = None if username: UserThroughTeam = User.alias() @@ -872,7 +875,7 @@ def create_repository(namespace, name, creating_user, visibility='private'): def create_or_link_image(docker_image_id, repository, username, create=True): - with db.transaction(): + with transaction_factory(db): query = (ImageStorage .select() .distinct() @@ -934,7 +937,7 @@ def set_image_size(docker_image_id, namespace_name, repository_name, def set_image_metadata(docker_image_id, namespace_name, repository_name, created_date_str, comment, command, parent=None): - with db.transaction(): + with transaction_factory(db): query = (Image .select(Image, ImageStorage) .join(Repository) @@ -980,7 +983,7 @@ def list_repository_tags(namespace_name, repository_name): def garbage_collect_repository(namespace_name, repository_name): - with db.transaction(): + with transaction_factory(db): # Get a list of all images used by tags in the repository tag_query = (RepositoryTag .select(RepositoryTag, Image, ImageStorage) @@ -1032,7 +1035,7 @@ def garbage_collect_repository(namespace_name, repository_name): storage.uuid) store.remove(image_path) - return len(to_remove) + return len(to_remove) def get_tag_image(namespace_name, repository_name, tag_name): diff --git a/data/queue.py b/data/queue.py index cf0acd898..46db150bf 100644 --- a/data/queue.py +++ b/data/queue.py @@ -1,6 +1,10 @@ from datetime import datetime, timedelta from data.database import QueueItem, db +from app import app + + +transaction_factory = app.config['DB_TRANSACTION_FACTORY'] class WorkQueue(object): @@ -34,7 +38,7 @@ class WorkQueue(object): available_or_expired = ((QueueItem.available == True) | (QueueItem.processing_expires <= now)) - with db.transaction(): + with transaction_factory(db): avail = QueueItem.select().where(QueueItem.queue_name == self.queue_name, QueueItem.available_after <= now, available_or_expired, diff --git a/endpoints/api.py b/endpoints/api.py index 1e694e1b0..4d79c26ca 100644 --- a/endpoints/api.py +++ b/endpoints/api.py @@ -25,7 +25,7 @@ from auth.permissions import (ReadRepositoryPermission, AdministerOrganizationPermission, OrganizationMemberPermission, ViewTeamPermission) -from endpoints.common import common_login +from endpoints.common import common_login, truthy_param from util.cache import cache_control from datetime import datetime, timedelta @@ -390,7 +390,7 @@ def get_matching_entities(prefix): if permission.can(): robot_namespace = namespace_name - if request.args.get('includeTeams', False): + if truthy_param(request.args.get('includeTeams', False)): teams = model.get_matching_teams(prefix, organization) except model.InvalidOrganizationException: @@ -984,20 +984,16 @@ def list_repos(): page = request.args.get('page', None) limit = request.args.get('limit', None) namespace_filter = request.args.get('namespace', None) - include_public = request.args.get('public', 'true') - include_private = request.args.get('private', 'true') - sort = request.args.get('sort', 'false') - include_count = request.args.get('count', 'false') + include_public = truthy_param(request.args.get('public', True)) + include_private = truthy_param(request.args.get('private', True)) + sort = truthy_param(request.args.get('sort', False)) + include_count = truthy_param(request.args.get('count', False)) try: limit = int(limit) if limit else None except TypeError: limit = None - include_public = include_public == 'true' - include_private = include_private == 'true' - include_count = include_count == 'true' - sort = sort == 'true' if page: try: page = int(page) diff --git a/endpoints/common.py b/endpoints/common.py index ec4727edb..0f9e027c8 100644 --- a/endpoints/common.py +++ b/endpoints/common.py @@ -14,6 +14,10 @@ from auth.permissions import QuayDeferredPermissionUser logger = logging.getLogger(__name__) +def truthy_param(param): + return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'} + + @login_manager.user_loader def load_user(username): logger.debug('Loading user: %s' % username) diff --git a/test/test_api_usage.py b/test/test_api_usage.py index efbe5849f..9a46a0615 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -661,7 +661,7 @@ class TestCreateRepo(ApiTestCase): class TestFindRepos(ApiTestCase): def test_findrepos_asguest(self): json = self.getJsonResponse('api.find_repos', params=dict(query='p')) - assert len(json['repositories']) == 1 + self.assertEquals(len(json['repositories']), 1) self.assertEquals(json['repositories'][0]['namespace'], 'public') self.assertEquals(json['repositories'][0]['name'], 'publicrepo') @@ -670,7 +670,7 @@ class TestFindRepos(ApiTestCase): self.login(NO_ACCESS_USER) json = self.getJsonResponse('api.find_repos', params=dict(query='p')) - assert len(json['repositories']) == 1 + self.assertEquals(len(json['repositories']), 1) self.assertEquals(json['repositories'][0]['namespace'], 'public') self.assertEquals(json['repositories'][0]['name'], 'publicrepo') @@ -679,18 +679,18 @@ class TestFindRepos(ApiTestCase): self.login(READ_ACCESS_USER) json = self.getJsonResponse('api.find_repos', params=dict(query='p')) - assert len(json['repositories']) > 1 + self.assertGreater(len(json['repositories']), 1) class TestListRepos(ApiTestCase): def test_listrepos_asguest(self): json = self.getJsonResponse('api.list_repos', params=dict(public=True)) - assert len(json['repositories']) == 0 + self.assertEquals(len(json['repositories']), 1) def test_listrepos_orgmember(self): self.login(READ_ACCESS_USER) json = self.getJsonResponse('api.list_repos', params=dict(public=True)) - assert len(json['repositories']) > 1 + self.assertGreater(len(json['repositories']), 1) def test_listrepos_filter(self): self.login(READ_ACCESS_USER) @@ -705,7 +705,7 @@ class TestListRepos(ApiTestCase): self.login(READ_ACCESS_USER) json = self.getJsonResponse('api.list_repos', params=dict(limit=2)) - assert len(json['repositories']) == 2 + self.assertEquals(len(json['repositories']), 2) class TestUpdateRepo(ApiTestCase):