Fix the tests and the one bug that it highlighted.
This commit is contained in:
parent
b619356907
commit
e7064f1191
6 changed files with 53 additions and 24 deletions
26
config.py
26
config.py
|
@ -33,7 +33,15 @@ class MailConfig(object):
|
|||
TESTING = False
|
||||
|
||||
|
||||
class SQLiteDB(object):
|
||||
class RealTransactions(object):
|
||||
@staticmethod
|
||||
def create_transaction(db):
|
||||
return db.transaction()
|
||||
|
||||
DB_TRANSACTION_FACTORY = create_transaction
|
||||
|
||||
|
||||
class SQLiteDB(RealTransactions):
|
||||
DB_NAME = 'test/data/test.db'
|
||||
DB_CONNECTION_ARGS = {
|
||||
'threadlocals': True
|
||||
|
@ -41,13 +49,27 @@ class SQLiteDB(object):
|
|||
DB_DRIVER = SqliteDatabase
|
||||
|
||||
|
||||
class FakeTransaction(object):
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
pass
|
||||
|
||||
|
||||
class EphemeralDB(object):
|
||||
DB_NAME = ':memory:'
|
||||
DB_CONNECTION_ARGS = {}
|
||||
DB_DRIVER = SqliteDatabase
|
||||
|
||||
@staticmethod
|
||||
def create_transaction(db):
|
||||
return FakeTransaction()
|
||||
|
||||
class RDSMySQL(object):
|
||||
DB_TRANSACTION_FACTORY = create_transaction
|
||||
|
||||
|
||||
class RDSMySQL(RealTransactions):
|
||||
DB_NAME = 'quay'
|
||||
DB_CONNECTION_ARGS = {
|
||||
'host': 'fluxmonkeylogin.cb0vumcygprn.us-east-1.rds.amazonaws.com',
|
||||
|
|
|
@ -13,7 +13,7 @@ from util.names import format_robot_username
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
store = app.config['STORAGE']
|
||||
|
||||
transaction_factory = app.config['DB_TRANSACTION_FACTORY']
|
||||
|
||||
class DataModelException(Exception):
|
||||
pass
|
||||
|
@ -580,6 +580,9 @@ def _visible_repository_query(username=None, include_public=True, limit=None,
|
|||
|
||||
def _filter_to_repos_for_user(query, username=None, namespace=None,
|
||||
include_public=True):
|
||||
if not include_public and not username:
|
||||
return Repository.select().where(Repository.id == '-1')
|
||||
|
||||
where_clause = None
|
||||
if username:
|
||||
UserThroughTeam = User.alias()
|
||||
|
@ -872,7 +875,7 @@ def create_repository(namespace, name, creating_user, visibility='private'):
|
|||
|
||||
|
||||
def create_or_link_image(docker_image_id, repository, username, create=True):
|
||||
with db.transaction():
|
||||
with transaction_factory(db):
|
||||
query = (ImageStorage
|
||||
.select()
|
||||
.distinct()
|
||||
|
@ -934,7 +937,7 @@ def set_image_size(docker_image_id, namespace_name, repository_name,
|
|||
|
||||
def set_image_metadata(docker_image_id, namespace_name, repository_name,
|
||||
created_date_str, comment, command, parent=None):
|
||||
with db.transaction():
|
||||
with transaction_factory(db):
|
||||
query = (Image
|
||||
.select(Image, ImageStorage)
|
||||
.join(Repository)
|
||||
|
@ -980,7 +983,7 @@ def list_repository_tags(namespace_name, repository_name):
|
|||
|
||||
|
||||
def garbage_collect_repository(namespace_name, repository_name):
|
||||
with db.transaction():
|
||||
with transaction_factory(db):
|
||||
# Get a list of all images used by tags in the repository
|
||||
tag_query = (RepositoryTag
|
||||
.select(RepositoryTag, Image, ImageStorage)
|
||||
|
@ -1032,7 +1035,7 @@ def garbage_collect_repository(namespace_name, repository_name):
|
|||
storage.uuid)
|
||||
store.remove(image_path)
|
||||
|
||||
return len(to_remove)
|
||||
return len(to_remove)
|
||||
|
||||
|
||||
def get_tag_image(namespace_name, repository_name, tag_name):
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from data.database import QueueItem, db
|
||||
from app import app
|
||||
|
||||
|
||||
transaction_factory = app.config['DB_TRANSACTION_FACTORY']
|
||||
|
||||
|
||||
class WorkQueue(object):
|
||||
|
@ -34,7 +38,7 @@ class WorkQueue(object):
|
|||
available_or_expired = ((QueueItem.available == True) |
|
||||
(QueueItem.processing_expires <= now))
|
||||
|
||||
with db.transaction():
|
||||
with transaction_factory(db):
|
||||
avail = QueueItem.select().where(QueueItem.queue_name == self.queue_name,
|
||||
QueueItem.available_after <= now,
|
||||
available_or_expired,
|
||||
|
|
|
@ -25,7 +25,7 @@ from auth.permissions import (ReadRepositoryPermission,
|
|||
AdministerOrganizationPermission,
|
||||
OrganizationMemberPermission,
|
||||
ViewTeamPermission)
|
||||
from endpoints.common import common_login
|
||||
from endpoints.common import common_login, truthy_param
|
||||
from util.cache import cache_control
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
@ -390,7 +390,7 @@ def get_matching_entities(prefix):
|
|||
if permission.can():
|
||||
robot_namespace = namespace_name
|
||||
|
||||
if request.args.get('includeTeams', False):
|
||||
if truthy_param(request.args.get('includeTeams', False)):
|
||||
teams = model.get_matching_teams(prefix, organization)
|
||||
|
||||
except model.InvalidOrganizationException:
|
||||
|
@ -984,20 +984,16 @@ def list_repos():
|
|||
page = request.args.get('page', None)
|
||||
limit = request.args.get('limit', None)
|
||||
namespace_filter = request.args.get('namespace', None)
|
||||
include_public = request.args.get('public', 'true')
|
||||
include_private = request.args.get('private', 'true')
|
||||
sort = request.args.get('sort', 'false')
|
||||
include_count = request.args.get('count', 'false')
|
||||
include_public = truthy_param(request.args.get('public', True))
|
||||
include_private = truthy_param(request.args.get('private', True))
|
||||
sort = truthy_param(request.args.get('sort', False))
|
||||
include_count = truthy_param(request.args.get('count', False))
|
||||
|
||||
try:
|
||||
limit = int(limit) if limit else None
|
||||
except TypeError:
|
||||
limit = None
|
||||
|
||||
include_public = include_public == 'true'
|
||||
include_private = include_private == 'true'
|
||||
include_count = include_count == 'true'
|
||||
sort = sort == 'true'
|
||||
if page:
|
||||
try:
|
||||
page = int(page)
|
||||
|
|
|
@ -14,6 +14,10 @@ from auth.permissions import QuayDeferredPermissionUser
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def truthy_param(param):
|
||||
return param not in {False, 'false', 'False', '0', 'FALSE', '', 'null'}
|
||||
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(username):
|
||||
logger.debug('Loading user: %s' % username)
|
||||
|
|
|
@ -661,7 +661,7 @@ class TestCreateRepo(ApiTestCase):
|
|||
class TestFindRepos(ApiTestCase):
|
||||
def test_findrepos_asguest(self):
|
||||
json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
|
||||
assert len(json['repositories']) == 1
|
||||
self.assertEquals(len(json['repositories']), 1)
|
||||
|
||||
self.assertEquals(json['repositories'][0]['namespace'], 'public')
|
||||
self.assertEquals(json['repositories'][0]['name'], 'publicrepo')
|
||||
|
@ -670,7 +670,7 @@ class TestFindRepos(ApiTestCase):
|
|||
self.login(NO_ACCESS_USER)
|
||||
|
||||
json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
|
||||
assert len(json['repositories']) == 1
|
||||
self.assertEquals(len(json['repositories']), 1)
|
||||
|
||||
self.assertEquals(json['repositories'][0]['namespace'], 'public')
|
||||
self.assertEquals(json['repositories'][0]['name'], 'publicrepo')
|
||||
|
@ -679,18 +679,18 @@ class TestFindRepos(ApiTestCase):
|
|||
self.login(READ_ACCESS_USER)
|
||||
|
||||
json = self.getJsonResponse('api.find_repos', params=dict(query='p'))
|
||||
assert len(json['repositories']) > 1
|
||||
self.assertGreater(len(json['repositories']), 1)
|
||||
|
||||
|
||||
class TestListRepos(ApiTestCase):
|
||||
def test_listrepos_asguest(self):
|
||||
json = self.getJsonResponse('api.list_repos', params=dict(public=True))
|
||||
assert len(json['repositories']) == 0
|
||||
self.assertEquals(len(json['repositories']), 1)
|
||||
|
||||
def test_listrepos_orgmember(self):
|
||||
self.login(READ_ACCESS_USER)
|
||||
json = self.getJsonResponse('api.list_repos', params=dict(public=True))
|
||||
assert len(json['repositories']) > 1
|
||||
self.assertGreater(len(json['repositories']), 1)
|
||||
|
||||
def test_listrepos_filter(self):
|
||||
self.login(READ_ACCESS_USER)
|
||||
|
@ -705,7 +705,7 @@ class TestListRepos(ApiTestCase):
|
|||
self.login(READ_ACCESS_USER)
|
||||
json = self.getJsonResponse('api.list_repos', params=dict(limit=2))
|
||||
|
||||
assert len(json['repositories']) == 2
|
||||
self.assertEquals(len(json['repositories']), 2)
|
||||
|
||||
|
||||
class TestUpdateRepo(ApiTestCase):
|
||||
|
|
Reference in a new issue