Fix the tests and the one bug that it highlighted.

This commit is contained in:
jakedt 2014-02-16 18:59:24 -05:00
parent b619356907
commit e7064f1191
6 changed files with 53 additions and 24 deletions

View file

@ -33,7 +33,15 @@ class MailConfig(object):
TESTING = False TESTING = False
class SQLiteDB(object): class RealTransactions(object):
@staticmethod
def create_transaction(db):
return db.transaction()
DB_TRANSACTION_FACTORY = create_transaction
class SQLiteDB(RealTransactions):
DB_NAME = 'test/data/test.db' DB_NAME = 'test/data/test.db'
DB_CONNECTION_ARGS = { DB_CONNECTION_ARGS = {
'threadlocals': True 'threadlocals': True
@ -41,13 +49,27 @@ class SQLiteDB(object):
DB_DRIVER = SqliteDatabase DB_DRIVER = SqliteDatabase
class FakeTransaction(object):
def __enter__(self):
return self
def __exit__(self, exc_type, value, traceback):
pass
class EphemeralDB(object): class EphemeralDB(object):
DB_NAME = ':memory:' DB_NAME = ':memory:'
DB_CONNECTION_ARGS = {} DB_CONNECTION_ARGS = {}
DB_DRIVER = SqliteDatabase DB_DRIVER = SqliteDatabase
@staticmethod
def create_transaction(db):
return FakeTransaction()
class RDSMySQL(object): DB_TRANSACTION_FACTORY = create_transaction
class RDSMySQL(RealTransactions):
DB_NAME = 'quay' DB_NAME = 'quay'
DB_CONNECTION_ARGS = { DB_CONNECTION_ARGS = {
'host': 'fluxmonkeylogin.cb0vumcygprn.us-east-1.rds.amazonaws.com', 'host': 'fluxmonkeylogin.cb0vumcygprn.us-east-1.rds.amazonaws.com',

View file

@ -13,7 +13,7 @@ from util.names import format_robot_username
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
store = app.config['STORAGE'] store = app.config['STORAGE']
transaction_factory = app.config['DB_TRANSACTION_FACTORY']
class DataModelException(Exception): class DataModelException(Exception):
pass pass
@ -580,6 +580,9 @@ def _visible_repository_query(username=None, include_public=True, limit=None,
def _filter_to_repos_for_user(query, username=None, namespace=None, def _filter_to_repos_for_user(query, username=None, namespace=None,
include_public=True): include_public=True):
if not include_public and not username:
return Repository.select().where(Repository.id == '-1')
where_clause = None where_clause = None
if username: if username:
UserThroughTeam = User.alias() UserThroughTeam = User.alias()
@ -872,7 +875,7 @@ def create_repository(namespace, name, creating_user, visibility='private'):
def create_or_link_image(docker_image_id, repository, username, create=True): def create_or_link_image(docker_image_id, repository, username, create=True):
with db.transaction(): with transaction_factory(db):
query = (ImageStorage query = (ImageStorage
.select() .select()
.distinct() .distinct()
@ -934,7 +937,7 @@ def set_image_size(docker_image_id, namespace_name, repository_name,
def set_image_metadata(docker_image_id, namespace_name, repository_name, def set_image_metadata(docker_image_id, namespace_name, repository_name,
created_date_str, comment, command, parent=None): created_date_str, comment, command, parent=None):
with db.transaction(): with transaction_factory(db):
query = (Image query = (Image
.select(Image, ImageStorage) .select(Image, ImageStorage)
.join(Repository) .join(Repository)
@ -980,7 +983,7 @@ def list_repository_tags(namespace_name, repository_name):
def garbage_collect_repository(namespace_name, repository_name): def garbage_collect_repository(namespace_name, repository_name):
with db.transaction(): with transaction_factory(db):
# Get a list of all images used by tags in the repository # Get a list of all images used by tags in the repository
tag_query = (RepositoryTag tag_query = (RepositoryTag
.select(RepositoryTag, Image, ImageStorage) .select(RepositoryTag, Image, ImageStorage)

View file

@ -1,6 +1,10 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from data.database import QueueItem, db from data.database import QueueItem, db
from app import app
transaction_factory = app.config['DB_TRANSACTION_FACTORY']
class WorkQueue(object): class WorkQueue(object):
@ -34,7 +38,7 @@ class WorkQueue(object):
available_or_expired = ((QueueItem.available == True) | available_or_expired = ((QueueItem.available == True) |
(QueueItem.processing_expires <= now)) (QueueItem.processing_expires <= now))
with db.transaction(): with transaction_factory(db):
avail = QueueItem.select().where(QueueItem.queue_name == self.queue_name, avail = QueueItem.select().where(QueueItem.queue_name == self.queue_name,
QueueItem.available_after <= now, QueueItem.available_after <= now,
available_or_expired, available_or_expired,

View file

@ -25,7 +25,7 @@ from auth.permissions import (ReadRepositoryPermission,
AdministerOrganizationPermission, AdministerOrganizationPermission,
OrganizationMemberPermission, OrganizationMemberPermission,
ViewTeamPermission) ViewTeamPermission)
from endpoints.common import common_login from endpoints.common import common_login, truthy_param
from util.cache import cache_control from util.cache import cache_control
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -390,7 +390,7 @@ def get_matching_entities(prefix):
if permission.can(): if permission.can():
robot_namespace = namespace_name robot_namespace = namespace_name
if request.args.get('includeTeams', False): if truthy_param(request.args.get('includeTeams', False)):
teams = model.get_matching_teams(prefix, organization) teams = model.get_matching_teams(prefix, organization)
except model.InvalidOrganizationException: except model.InvalidOrganizationException:
@ -984,20 +984,16 @@ def list_repos():
page = request.args.get('page', None) page = request.args.get('page', None)
limit = request.args.get('limit', None) limit = request.args.get('limit', None)
namespace_filter = request.args.get('namespace', None) namespace_filter = request.args.get('namespace', None)
include_public = request.args.get('public', 'true') include_public = truthy_param(request.args.get('public', True))
include_private = request.args.get('private', 'true') include_private = truthy_param(request.args.get('private', True))
sort = request.args.get('sort', 'false') sort = truthy_param(request.args.get('sort', False))
include_count = request.args.get('count', 'false') include_count = truthy_param(request.args.get('count', False))
try: try:
limit = int(limit) if limit else None limit = int(limit) if limit else None
except TypeError: except TypeError:
limit = None limit = None
include_public = include_public == 'true'
include_private = include_private == 'true'
include_count = include_count == 'true'
sort = sort == 'true'
if page: if page:
try: try:
page = int(page) page = int(page)

View file

@ -14,6 +14,10 @@ from auth.permissions import QuayDeferredPermissionUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def truthy_param(param):
return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'}
@login_manager.user_loader @login_manager.user_loader
def load_user(username): def load_user(username):
logger.debug('Loading user: %s' % username) logger.debug('Loading user: %s' % username)

View file

@ -661,7 +661,7 @@ class TestCreateRepo(ApiTestCase):
class TestFindRepos(ApiTestCase): class TestFindRepos(ApiTestCase):
def test_findrepos_asguest(self): def test_findrepos_asguest(self):
json = self.getJsonResponse('api.find_repos', params=dict(query='p')) json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
assert len(json['repositories']) == 1 self.assertEquals(len(json['repositories']), 1)
self.assertEquals(json['repositories'][0]['namespace'], 'public') self.assertEquals(json['repositories'][0]['namespace'], 'public')
self.assertEquals(json['repositories'][0]['name'], 'publicrepo') self.assertEquals(json['repositories'][0]['name'], 'publicrepo')
@ -670,7 +670,7 @@ class TestFindRepos(ApiTestCase):
self.login(NO_ACCESS_USER) self.login(NO_ACCESS_USER)
json = self.getJsonResponse('api.find_repos', params=dict(query='p')) json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
assert len(json['repositories']) == 1 self.assertEquals(len(json['repositories']), 1)
self.assertEquals(json['repositories'][0]['namespace'], 'public') self.assertEquals(json['repositories'][0]['namespace'], 'public')
self.assertEquals(json['repositories'][0]['name'], 'publicrepo') self.assertEquals(json['repositories'][0]['name'], 'publicrepo')
@ -679,18 +679,18 @@ class TestFindRepos(ApiTestCase):
self.login(READ_ACCESS_USER) self.login(READ_ACCESS_USER)
json = self.getJsonResponse('api.find_repos', params=dict(query='p')) json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
assert len(json['repositories']) > 1 self.assertGreater(len(json['repositories']), 1)
class TestListRepos(ApiTestCase): class TestListRepos(ApiTestCase):
def test_listrepos_asguest(self): def test_listrepos_asguest(self):
json = self.getJsonResponse('api.list_repos', params=dict(public=True)) json = self.getJsonResponse('api.list_repos', params=dict(public=True))
assert len(json['repositories']) == 0 self.assertEquals(len(json['repositories']), 1)
def test_listrepos_orgmember(self): def test_listrepos_orgmember(self):
self.login(READ_ACCESS_USER) self.login(READ_ACCESS_USER)
json = self.getJsonResponse('api.list_repos', params=dict(public=True)) json = self.getJsonResponse('api.list_repos', params=dict(public=True))
assert len(json['repositories']) > 1 self.assertGreater(len(json['repositories']), 1)
def test_listrepos_filter(self): def test_listrepos_filter(self):
self.login(READ_ACCESS_USER) self.login(READ_ACCESS_USER)
@ -705,7 +705,7 @@ class TestListRepos(ApiTestCase):
self.login(READ_ACCESS_USER) self.login(READ_ACCESS_USER)
json = self.getJsonResponse('api.list_repos', params=dict(limit=2)) json = self.getJsonResponse('api.list_repos', params=dict(limit=2))
assert len(json['repositories']) == 2 self.assertEquals(len(json['repositories']), 2)
class TestUpdateRepo(ApiTestCase): class TestUpdateRepo(ApiTestCase):