Fix the tests and the one bug that it highlighted.

This commit is contained in:
jakedt 2014-02-16 18:59:24 -05:00
parent b619356907
commit e7064f1191
6 changed files with 53 additions and 24 deletions

View file

@ -33,7 +33,15 @@ class MailConfig(object):
TESTING = False
class SQLiteDB(object):
class RealTransactions(object):
@staticmethod
def create_transaction(db):
return db.transaction()
DB_TRANSACTION_FACTORY = create_transaction
class SQLiteDB(RealTransactions):
DB_NAME = 'test/data/test.db'
DB_CONNECTION_ARGS = {
'threadlocals': True
@ -41,13 +49,27 @@ class SQLiteDB(object):
DB_DRIVER = SqliteDatabase
class FakeTransaction(object):
def __enter__(self):
return self
def __exit__(self, exc_type, value, traceback):
pass
class EphemeralDB(object):
DB_NAME = ':memory:'
DB_CONNECTION_ARGS = {}
DB_DRIVER = SqliteDatabase
@staticmethod
def create_transaction(db):
return FakeTransaction()
class RDSMySQL(object):
DB_TRANSACTION_FACTORY = create_transaction
class RDSMySQL(RealTransactions):
DB_NAME = 'quay'
DB_CONNECTION_ARGS = {
'host': 'fluxmonkeylogin.cb0vumcygprn.us-east-1.rds.amazonaws.com',

View file

@ -13,7 +13,7 @@ from util.names import format_robot_username
logger = logging.getLogger(__name__)
store = app.config['STORAGE']
transaction_factory = app.config['DB_TRANSACTION_FACTORY']
class DataModelException(Exception):
pass
@ -580,6 +580,9 @@ def _visible_repository_query(username=None, include_public=True, limit=None,
def _filter_to_repos_for_user(query, username=None, namespace=None,
include_public=True):
if not include_public and not username:
return Repository.select().where(Repository.id == '-1')
where_clause = None
if username:
UserThroughTeam = User.alias()
@ -872,7 +875,7 @@ def create_repository(namespace, name, creating_user, visibility='private'):
def create_or_link_image(docker_image_id, repository, username, create=True):
with db.transaction():
with transaction_factory(db):
query = (ImageStorage
.select()
.distinct()
@ -934,7 +937,7 @@ def set_image_size(docker_image_id, namespace_name, repository_name,
def set_image_metadata(docker_image_id, namespace_name, repository_name,
created_date_str, comment, command, parent=None):
with db.transaction():
with transaction_factory(db):
query = (Image
.select(Image, ImageStorage)
.join(Repository)
@ -980,7 +983,7 @@ def list_repository_tags(namespace_name, repository_name):
def garbage_collect_repository(namespace_name, repository_name):
with db.transaction():
with transaction_factory(db):
# Get a list of all images used by tags in the repository
tag_query = (RepositoryTag
.select(RepositoryTag, Image, ImageStorage)
@ -1032,7 +1035,7 @@ def garbage_collect_repository(namespace_name, repository_name):
storage.uuid)
store.remove(image_path)
return len(to_remove)
return len(to_remove)
def get_tag_image(namespace_name, repository_name, tag_name):

View file

@ -1,6 +1,10 @@
from datetime import datetime, timedelta
from data.database import QueueItem, db
from app import app
transaction_factory = app.config['DB_TRANSACTION_FACTORY']
class WorkQueue(object):
@ -34,7 +38,7 @@ class WorkQueue(object):
available_or_expired = ((QueueItem.available == True) |
(QueueItem.processing_expires <= now))
with db.transaction():
with transaction_factory(db):
avail = QueueItem.select().where(QueueItem.queue_name == self.queue_name,
QueueItem.available_after <= now,
available_or_expired,

View file

@ -25,7 +25,7 @@ from auth.permissions import (ReadRepositoryPermission,
AdministerOrganizationPermission,
OrganizationMemberPermission,
ViewTeamPermission)
from endpoints.common import common_login
from endpoints.common import common_login, truthy_param
from util.cache import cache_control
from datetime import datetime, timedelta
@ -390,7 +390,7 @@ def get_matching_entities(prefix):
if permission.can():
robot_namespace = namespace_name
if request.args.get('includeTeams', False):
if truthy_param(request.args.get('includeTeams', False)):
teams = model.get_matching_teams(prefix, organization)
except model.InvalidOrganizationException:
@ -984,20 +984,16 @@ def list_repos():
page = request.args.get('page', None)
limit = request.args.get('limit', None)
namespace_filter = request.args.get('namespace', None)
include_public = request.args.get('public', 'true')
include_private = request.args.get('private', 'true')
sort = request.args.get('sort', 'false')
include_count = request.args.get('count', 'false')
include_public = truthy_param(request.args.get('public', True))
include_private = truthy_param(request.args.get('private', True))
sort = truthy_param(request.args.get('sort', False))
include_count = truthy_param(request.args.get('count', False))
try:
limit = int(limit) if limit else None
except TypeError:
limit = None
include_public = include_public == 'true'
include_private = include_private == 'true'
include_count = include_count == 'true'
sort = sort == 'true'
if page:
try:
page = int(page)

View file

@ -14,6 +14,10 @@ from auth.permissions import QuayDeferredPermissionUser
logger = logging.getLogger(__name__)
def truthy_param(param):
return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'}
@login_manager.user_loader
def load_user(username):
logger.debug('Loading user: %s' % username)

View file

@ -661,7 +661,7 @@ class TestCreateRepo(ApiTestCase):
class TestFindRepos(ApiTestCase):
def test_findrepos_asguest(self):
json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
assert len(json['repositories']) == 1
self.assertEquals(len(json['repositories']), 1)
self.assertEquals(json['repositories'][0]['namespace'], 'public')
self.assertEquals(json['repositories'][0]['name'], 'publicrepo')
@ -670,7 +670,7 @@ class TestFindRepos(ApiTestCase):
self.login(NO_ACCESS_USER)
json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
assert len(json['repositories']) == 1
self.assertEquals(len(json['repositories']), 1)
self.assertEquals(json['repositories'][0]['namespace'], 'public')
self.assertEquals(json['repositories'][0]['name'], 'publicrepo')
@ -679,18 +679,18 @@ class TestFindRepos(ApiTestCase):
self.login(READ_ACCESS_USER)
json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
assert len(json['repositories']) > 1
self.assertGreater(len(json['repositories']), 1)
class TestListRepos(ApiTestCase):
def test_listrepos_asguest(self):
json = self.getJsonResponse('api.list_repos', params=dict(public=True))
assert len(json['repositories']) == 0
self.assertEquals(len(json['repositories']), 1)
def test_listrepos_orgmember(self):
self.login(READ_ACCESS_USER)
json = self.getJsonResponse('api.list_repos', params=dict(public=True))
assert len(json['repositories']) > 1
self.assertGreater(len(json['repositories']), 1)
def test_listrepos_filter(self):
self.login(READ_ACCESS_USER)
@ -705,7 +705,7 @@ class TestListRepos(ApiTestCase):
self.login(READ_ACCESS_USER)
json = self.getJsonResponse('api.list_repos', params=dict(limit=2))
assert len(json['repositories']) == 2
self.assertEquals(len(json['repositories']), 2)
class TestUpdateRepo(ApiTestCase):