Upgrade Peewee to latest 3.x

This requires a number of small changes in the data model code, as well as additional testing.
Brad Ison 2018-04-06 13:48:01 -04:00 committed by Joseph Schorr
parent 70b7ee4654
commit d3d9cca182
26 changed files with 220 additions and 193 deletions
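
Most of those small changes are mechanical renames in the peewee 3.x API: ForeignKeyField's rel_model/related_name become model/backref, db_column becomes column_name, Meta.db_table becomes table_name, the JOIN_LEFT_OUTER constant becomes JOIN.LEFT_OUTER, Clause becomes NodeList, field.model_class becomes field.model, and instance._data becomes instance.__data__. A minimal sketch of the new spellings, using simplified stand-in models rather than the real Quay definitions:

from peewee import SqliteDatabase, Model, CharField, ForeignKeyField, JOIN

db = SqliteDatabase(':memory:')

class User(Model):
    username = CharField(column_name='username')   # peewee 2.x: db_column='username'
    class Meta:
        database = db
        table_name = 'user'                         # peewee 2.x: db_table = 'user'

class Repository(Model):
    name = CharField()
    namespace_user = ForeignKeyField(User, backref='repositories')  # peewee 2.x: related_name=
    class Meta:
        database = db

db.create_tables([User, Repository])
owner = User.create(username='devtable')
Repository.create(name='simple', namespace_user=owner)

# Joins now take the JOIN enum instead of the removed JOIN_LEFT_OUTER constant.
query = (Repository
         .select(Repository, User)
         .join(User, JOIN.LEFT_OUTER))
for repo in query:
    print(repo.name, repo.namespace_user.username)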


@@ -17,7 +17,7 @@ import toposort
from enum import Enum
from peewee import *
from playhouse.shortcuts import RetryOperationalError
from peewee import __exception_wrapper__, Function
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase, PooledSqliteDatabase
from sqlalchemy.engine.url import make_url
@@ -121,27 +121,27 @@ def delete_instance_filtered(instance, model_class, delete_nullable, skip_transi
# We only want to skip transitive deletes, which are done using subqueries in the form of
# DELETE FROM <table> in <subquery>. If an op is not using a subquery, we allow it to be
# applied directly.
if fk.model_class not in skip_transitive_deletes or query.op != 'in':
if fk.model not in skip_transitive_deletes or query.op.lower() != 'in':
filtered_ops.append((query, fk))
if query.op == 'in':
dependencies[fk.model_class.__name__].add(query.rhs.model_class.__name__)
if query.op.lower() == 'in':
dependencies[fk.model.__name__].add(query.rhs.model.__name__)
elif query.op == '=':
dependencies[fk.model_class.__name__].add(model_class.__name__)
dependencies[fk.model.__name__].add(model_class.__name__)
else:
raise RuntimeError('Unknown operator in recursive repository delete query')
sorted_models = list(reversed(toposort.toposort_flatten(dependencies)))
def sorted_model_key(query_fk_tuple):
cmp_query, cmp_fk = query_fk_tuple
if cmp_query.op == 'in':
if cmp_query.op.lower() == 'in':
return -1
return sorted_models.index(cmp_fk.model_class.__name__)
return sorted_models.index(cmp_fk.model.__name__)
filtered_ops.sort(key=sorted_model_key)
with db_transaction():
for query, fk in filtered_ops:
_model = fk.model_class
_model = fk.model
if fk.null and not delete_nullable:
_model.update(**{fk.name: None}).where(query).execute()
else:
@@ -162,6 +162,24 @@ class CallableProxy(Proxy):
return self.obj(*args, **kwargs)
class RetryOperationalError(object):
def execute_sql(self, sql, params=None, commit=True):
try:
cursor = super(RetryOperationalError, self).execute_sql(sql, params, commit)
except OperationalError:
if not self.is_closed():
self.close()
with __exception_wrapper__:
cursor = self.cursor()
cursor.execute(sql, params or ())
if commit and not self.in_transaction():
self.commit()
return cursor
class CloseForLongOperation(object):
""" Helper object which disconnects the database then reconnects after the nested operation
completes.
@@ -214,11 +232,11 @@ class TupleSelector(object):
@classmethod
def tuple_reference_key(cls, field):
""" Returns a string key for referencing a field in a TupleSelector. """
if field._node_type == 'func':
if isinstance(field, Function):
return field.name + ','.join([cls.tuple_reference_key(arg) for arg in field.arguments])
if field._node_type == 'field':
return field.name + ':' + field.model_class.__name__
if isinstance(field, Field):
return field.name + ':' + field.model.__name__
raise Exception('Unknown field type %s in TupleSelector' % field._node_type)
@@ -268,6 +286,9 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
if parsed_url.password:
db_kwargs['password'] = parsed_url.password
# Remove threadlocals. It used to be required.
db_kwargs.pop('threadlocals', None)
# Note: sqlite does not support connect_timeout.
if parsed_url.drivername != 'sqlite':
db_kwargs['connect_timeout'] = db_kwargs.get('connect_timeout', connect_timeout)
@@ -285,8 +306,9 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
db_kwargs.pop('stale_timeout', None)
db_kwargs.pop('max_connections', None)
wrapped_driver = _wrap_for_retry(driver)
return wrapped_driver(parsed_url.database, **db_kwargs)
# wrapped_driver = _wrap_for_retry(driver)
# return wrapped_driver(parsed_url.database, **db_kwargs)
return driver(parsed_url.database, **db_kwargs)
def configure(config_object):
@@ -351,20 +373,20 @@ class QuayUserField(ForeignKeyField):
def __init__(self, allows_robots=False, robot_null_delete=False, *args, **kwargs):
self.allows_robots = allows_robots
self.robot_null_delete = robot_null_delete
if 'rel_model' not in kwargs:
kwargs['rel_model'] = User
if 'model' not in kwargs:
kwargs['model'] = User
super(QuayUserField, self).__init__(*args, **kwargs)
class EnumField(ForeignKeyField):
""" Create a cached python Enum from an EnumTable """
def __init__(self, rel_model, enum_key_field='name', *args, **kwargs):
def __init__(self, model, enum_key_field='name', *args, **kwargs):
"""
rel_model is the EnumTable model-class (see ForeignKeyField)
model is the EnumTable model-class (see ForeignKeyField)
enum_key_field is the field from the EnumTable to use as the enum name
"""
self.enum_key_field = enum_key_field
super(EnumField, self).__init__(rel_model, *args, **kwargs)
super(EnumField, self).__init__(model, *args, **kwargs)
@property
@lru_cache(maxsize=1)
@@ -412,7 +434,7 @@ class BaseModel(ReadSlaveModel):
if name.endswith('_id'):
field_name = name[0:len(name) - 3]
if field_name in self._meta.fields:
return self._data.get(field_name)
return self.__data__.get(field_name)
return super(BaseModel, self).__getattribute__(name)
@@ -449,7 +471,7 @@ class User(BaseModel):
# For all the model dependencies, only delete those that allow robots.
for query, fk in reversed(list(self.dependencies(search_nullable=True))):
if isinstance(fk, QuayUserField) and fk.allows_robots:
_model = fk.model_class
_model = fk.model
if fk.robot_null_delete:
_model.update(**{fk.name: None}).where(query).execute()
@@ -551,7 +573,7 @@ class TeamMemberInvite(BaseModel):
user = QuayUserField(index=True, null=True)
email = CharField(null=True)
team = ForeignKeyField(Team)
inviter = ForeignKeyField(User, related_name='inviter')
inviter = ForeignKeyField(User, backref='inviter')
invite_token = CharField(default=urn_generator(['teaminvite']))
@@ -664,13 +686,13 @@ class RepositoryPermission(BaseModel):
class PermissionPrototype(BaseModel):
org = QuayUserField(index=True, related_name='orgpermissionproto')
org = QuayUserField(index=True, backref='orgpermissionproto')
uuid = CharField(default=uuid_generator)
activating_user = QuayUserField(allows_robots=True, index=True, null=True,
related_name='userpermissionproto')
delegate_user = QuayUserField(allows_robots=True, related_name='receivingpermission',
backref='userpermissionproto')
delegate_user = QuayUserField(allows_robots=True, backref='receivingpermission',
null=True)
delegate_team = ForeignKeyField(Team, related_name='receivingpermission',
delegate_team = ForeignKeyField(Team, backref='receivingpermission',
null=True)
role = ForeignKeyField(Role)
@@ -714,7 +736,7 @@ class RepositoryBuildTrigger(BaseModel):
private_key = TextField(null=True)
config = TextField(default='{}')
write_token = ForeignKeyField(AccessToken, null=True)
pull_robot = QuayUserField(allows_robots=True, null=True, related_name='triggerpullrobot',
pull_robot = QuayUserField(allows_robots=True, null=True, backref='triggerpullrobot',
robot_null_delete=True)
enabled = BooleanField(default=True)
disabled_reason = EnumField(DisableReason, null=True)
@@ -789,9 +811,6 @@ class UserRegion(BaseModel):
)
_ImageProxy = Proxy()
class Image(BaseModel):
# This class is intentionally denormalized. Even though images are supposed
# to be globally unique we can't treat them as such for permissions and
@@ -816,7 +835,7 @@ class Image(BaseModel):
security_indexed_engine = IntegerField(default=IMAGE_NOT_SCANNED_ENGINE_VERSION, index=True)
# We use a proxy here instead of 'self' in order to disable the foreign key constraint
parent = ForeignKeyField(_ImageProxy, null=True, related_name='children')
parent = DeferredForeignKey('Image', null=True, backref='children')
class Meta:
database = db
@@ -835,9 +854,6 @@ class Image(BaseModel):
return map(int, self.ancestors.split('/')[1:-1])
_ImageProxy.initialize(Image)
class DerivedStorageForImage(BaseModel):
source_image = ForeignKeyField(Image)
derivative = ForeignKeyField(ImageStorage)
@@ -942,7 +958,7 @@ class RepositoryBuild(BaseModel):
started = DateTimeField(default=datetime.now, index=True)
display_name = CharField()
trigger = ForeignKeyField(RepositoryBuildTrigger, null=True)
pull_robot = QuayUserField(null=True, related_name='buildpullrobot', allows_robots=True,
pull_robot = QuayUserField(null=True, backref='buildpullrobot', allows_robots=True,
robot_null_delete=True)
logs_archived = BooleanField(default=False)
queue_id = CharField(null=True, index=True)
@@ -962,9 +978,9 @@ class LogEntryKind(BaseModel):
class LogEntry(BaseModel):
kind = ForeignKeyField(LogEntryKind)
account = IntegerField(index=True, db_column='account_id')
performer = IntegerField(index=True, null=True, db_column='performer_id')
repository = IntegerField(index=True, null=True, db_column='repository_id')
account = IntegerField(index=True, column_name='account_id')
performer = IntegerField(index=True, null=True, column_name='performer_id')
repository = IntegerField(index=True, null=True, column_name='repository_id')
datetime = DateTimeField(default=datetime.now, index=True)
ip = CharField(null=True)
metadata_json = TextField(default='{}')
@@ -1024,7 +1040,7 @@ class OAuthApplication(BaseModel):
name = CharField()
description = TextField(default='')
avatar_email = CharField(null=True, db_column='gravatar_email')
avatar_email = CharField(null=True, column_name='gravatar_email')
class OAuthAuthorizationCode(BaseModel):
@@ -1163,15 +1179,12 @@ class ServiceKeyApprovalType(Enum):
AUTOMATIC = 'Automatic'
_ServiceKeyApproverProxy = Proxy()
class ServiceKeyApproval(BaseModel):
approver = ForeignKeyField(_ServiceKeyApproverProxy, null=True)
approver = QuayUserField(null=True)
approval_type = CharField(index=True)
approved_date = DateTimeField(default=datetime.utcnow)
notes = TextField(default='')
_ServiceKeyApproverProxy.initialize(User)
class ServiceKey(BaseModel):
name = CharField()
@@ -1309,7 +1322,7 @@ class ApprTag(BaseModel):
reverted = BooleanField(default=False)
protected = BooleanField(default=False)
tag_kind = EnumField(ApprTagKind)
linked_tag = ForeignKeyField('self', null=True, related_name='tag_parents')
linked_tag = ForeignKeyField('self', null=True, backref='tag_parents')
class Meta:
database = db


@@ -2,7 +2,7 @@ import base64
import resumablehashlib
import json
from peewee import TextField, CharField, Clause
from peewee import TextField, CharField
from data.text import prefix_search


@@ -94,9 +94,7 @@ def filter_to_repos_for_user(query, user_id=None, namespace=None, repo_kind='ima
queries = []
if include_public:
queries.append(query
.clone()
.where(Repository.visibility == get_public_repo_visibility()))
queries.append(query.where(Repository.visibility == get_public_repo_visibility()))
if user_id is not None:
AdminTeam = Team.alias()
@@ -104,13 +102,11 @@ def filter_to_repos_for_user(query, user_id=None, namespace=None, repo_kind='ima
# Add repositories in which the user has permission.
queries.append(query
.clone()
.switch(RepositoryPermission)
.where(RepositoryPermission.user == user_id))
# Add repositories in which the user is a member of a team that has permission.
queries.append(query
.clone()
.switch(RepositoryPermission)
.join(Team)
.join(TeamMember)
@@ -118,7 +114,6 @@ def filter_to_repos_for_user(query, user_id=None, namespace=None, repo_kind='ima
# Add repositories under namespaces in which the user is the org admin.
queries.append(query
.clone()
.switch(Repository)
.join(AdminTeam, on=(Repository.namespace_user == AdminTeam.organization))
.join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))


@@ -2,7 +2,7 @@ import json
import os
from datetime import timedelta, datetime
from peewee import JOIN_LEFT_OUTER
from peewee import JOIN
import features
from data.database import (BuildTriggerService, RepositoryBuildTrigger, Repository, Namespace, User,
@@ -50,7 +50,7 @@ def get_build_trigger(trigger_uuid):
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(RepositoryBuildTrigger)
.join(User)
.join(User, on=(RepositoryBuildTrigger.connected_user == User.id))
.where(RepositoryBuildTrigger.uuid == trigger_uuid)
.get())
except RepositoryBuildTrigger.DoesNotExist:
@@ -94,10 +94,10 @@ def _get_build_base_query():
.join(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(RepositoryBuild)
.join(User, JOIN_LEFT_OUTER)
.join(User, JOIN.LEFT_OUTER)
.switch(RepositoryBuild)
.join(RepositoryBuildTrigger, JOIN_LEFT_OUTER)
.join(BuildTriggerService, JOIN_LEFT_OUTER)
.join(RepositoryBuildTrigger, JOIN.LEFT_OUTER)
.join(BuildTriggerService, JOIN.LEFT_OUTER)
.order_by(RepositoryBuild.started.desc()))
@@ -308,4 +308,3 @@ def update_trigger_disable_status(trigger, final_phase):
else:
# Save the trigger changes.
trigger.save()


@@ -6,7 +6,7 @@ from collections import defaultdict
from datetime import datetime
import dateutil.parser
from peewee import JOIN_LEFT_OUTER, IntegrityError, fn
from peewee import JOIN, IntegrityError, fn
from data.model import (DataModelException, db_transaction, _basequery, storage,
InvalidImageException)
@@ -273,7 +273,7 @@ def find_create_or_link_image(docker_image_id, repo_obj, username, translations,
.join(ImageStorage)
.switch(Image)
.join(Repository)
.join(RepositoryPermission, JOIN_LEFT_OUTER)
.join(RepositoryPermission, JOIN.LEFT_OUTER)
.switch(Repository)
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.where(ImageStorage.uploading == False,
@@ -445,8 +445,8 @@ def get_image_with_storage_and_parent_base():
.select(Image, ImageStorage, Parent, ParentImageStorage)
.join(ImageStorage)
.switch(Image)
.join(Parent, JOIN_LEFT_OUTER, on=(Image.parent == Parent.id))
.join(ParentImageStorage, JOIN_LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage)))
.join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
.join(ParentImageStorage, JOIN.LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage)))
def set_secscan_status(image, indexed, version):


@@ -2,7 +2,7 @@ import json
import logging
from calendar import timegm
from peewee import JOIN_LEFT_OUTER, fn, PeeweeException
from peewee import JOIN, fn, PeeweeException
from datetime import datetime, timedelta
from cachetools import lru_cache
@@ -82,11 +82,11 @@ def get_logs_query(start_time, end_time, performer=None, repository=None, namesp
query = _logs_query(selections, start_time, end_time, performer, repository, namespace, ignore,
model=model)
query = (query.switch(model).join(Performer, JOIN_LEFT_OUTER,
query = (query.switch(model).join(Performer, JOIN.LEFT_OUTER,
on=(model.performer == Performer.id).alias('performer')))
if namespace is None and repository is None:
query = (query.switch(model).join(Account, JOIN_LEFT_OUTER,
query = (query.switch(model).join(Account, JOIN.LEFT_OUTER,
on=(model.account == Account.id).alias('account')))
return query


@@ -37,7 +37,16 @@ def lookup_notifications_by_path_prefix(prefix):
def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=False,
page=None, limit=None):
base_query = Notification.select().join(NotificationKind)
base_query = (Notification
.select(Notification.id,
Notification.uuid,
Notification.kind,
Notification.metadata_json,
Notification.dismissed,
Notification.lookup_path,
Notification.created.alias('cd'),
Notification.target)
.join(NotificationKind))
if kind_name is not None:
base_query = base_query.where(NotificationKind.name == kind_name)
@@ -73,7 +82,8 @@ def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=F
elif limit:
query = query.limit(limit)
return query.order_by(base_query.c.created.desc())
from peewee import SQL
return query.order_by(SQL('cd desc'))
def delete_all_notifications_by_path_prefix(prefix):


@@ -1,4 +1,4 @@
from peewee import JOIN_LEFT_OUTER
from peewee import JOIN
from data.database import (RepositoryPermission, User, Repository, Visibility, Role, TeamMember,
PermissionPrototype, Team, TeamRole, Namespace)
@@ -112,13 +112,13 @@ def get_prototype_permissions(org):
query = (PermissionPrototype
.select()
.where(PermissionPrototype.org == org)
.join(ActivatingUser, JOIN_LEFT_OUTER,
.join(ActivatingUser, JOIN.LEFT_OUTER,
on=(ActivatingUser.id == PermissionPrototype.activating_user))
.join(DelegateUser, JOIN_LEFT_OUTER,
.join(DelegateUser, JOIN.LEFT_OUTER,
on=(DelegateUser.id == PermissionPrototype.delegate_user))
.join(Team, JOIN_LEFT_OUTER,
.join(Team, JOIN.LEFT_OUTER,
on=(Team.id == PermissionPrototype.delegate_team))
.join(Role, JOIN_LEFT_OUTER, on=(Role.id == PermissionPrototype.role)))
.join(Role, JOIN.LEFT_OUTER, on=(Role.id == PermissionPrototype.role)))
return query


@@ -3,8 +3,7 @@ import random
from enum import Enum
from datetime import timedelta, datetime
from peewee import JOIN_LEFT_OUTER, fn, SQL, IntegrityError
from playhouse.shortcuts import case
from peewee import Case, JOIN, fn, SQL, IntegrityError
from cachetools import ttl_cache
from data.model import (
@@ -406,7 +405,7 @@ def get_visible_repositories(username, namespace=None, kind_filter='image', incl
user_id = None
if username:
# Note: We only need the permissions table if we will filter based on a user's permissions.
query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN_LEFT_OUTER)
query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN.LEFT_OUTER)
found_namespace = _get_namespace_user(username)
if not found_namespace:
return Repository.select(Repository.id.alias('rid')).where(Repository.id == -1)
@@ -552,7 +551,7 @@ def _get_sorted_matching_repositories(lookup_value, repo_kind='image', include_p
if SEARCH_FIELDS.description.name in search_fields:
clause = Repository.description.match(lookup_value) | clause
cases = [(Repository.name.match(lookup_value), 100 * RepositorySearchScore.score),]
computed_score = case(None, cases, RepositorySearchScore.score).alias('score')
computed_score = Case(None, cases, RepositorySearchScore.score).alias('score')
select_fields.append(computed_score)
query = (Repository.select(*select_fields)


@@ -2,7 +2,7 @@ import re
from calendar import timegm
from datetime import datetime, timedelta
from peewee import JOIN_LEFT_OUTER
from peewee import JOIN
from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
@@ -165,7 +165,7 @@ def approve_service_key(kid, approver, approval_type, notes=''):
def _list_service_keys_query(kid=None, service=None, approved_only=True, alive_only=True,
approval_type=None):
query = ServiceKey.select().join(ServiceKeyApproval, JOIN_LEFT_OUTER)
query = ServiceKey.select().join(ServiceKeyApproval, JOIN.LEFT_OUTER)
if approved_only:
query = query.where(~(ServiceKey.approval >> None))


@@ -50,8 +50,7 @@ def gen_sqlalchemy_metadata(peewee_model_list):
alchemy_type = Integer
all_indexes.add(((field.name, ), field.unique))
if not field.deferred:
target_name = '%s.%s' % (field.to_field.model_class._meta.db_table,
field.to_field.db_column)
target_name = '%s.%s' % (field.rel_model._meta.table_name, field.rel_field.column_name)
col_args.append(ForeignKey(target_name))
elif isinstance(field, BigIntegerField):
alchemy_type = BigInteger
@@ -74,19 +73,19 @@ def gen_sqlalchemy_metadata(peewee_model_list):
if field.unique or field.index:
all_indexes.add(((field.name, ), field.unique))
new_col = Column(field.db_column, alchemy_type, *col_args, **col_kwargs)
new_col = Column(field.column_name, alchemy_type, *col_args, **col_kwargs)
columns.append(new_col)
new_table = Table(meta.db_table, metadata, *columns)
new_table = Table(meta.table_name, metadata, *columns)
for col_prop_names, unique in all_indexes:
col_names = [meta.fields[prop_name].db_column for prop_name in col_prop_names]
index_name = '%s_%s' % (meta.db_table, '_'.join(col_names))
col_names = [meta.fields[prop_name].column_name for prop_name in col_prop_names]
index_name = '%s_%s' % (meta.table_name, '_'.join(col_names))
col_refs = [getattr(new_table.c, col_name) for col_name in col_names]
Index(index_name, *col_refs, unique=unique)
for col_field_name in fulltext_indexes:
index_name = '%s_%s__fulltext' % (meta.db_table, col_field_name)
index_name = '%s_%s__fulltext' % (meta.table_name, col_field_name)
col_ref = getattr(new_table.c, col_field_name)
Index(index_name, col_ref, postgresql_ops={col_field_name: 'gin_trgm_ops'},
postgresql_using='gin',


@@ -4,7 +4,7 @@ import time
from calendar import timegm
from uuid import uuid4
from peewee import IntegrityError, JOIN_LEFT_OUTER, fn
from peewee import IntegrityError, JOIN, fn
from data.model import (image, db_transaction, DataModelException, _basequery,
InvalidManifestException, TagAlreadyCreatedException, StaleTagException,
config)
@@ -44,8 +44,8 @@ def get_tags_images_eligible_for_scan(clair_version):
.join(Image, on=(RepositoryTag.image == Image.id))
.join(ImageStorage, on=(Image.storage == ImageStorage.id))
.switch(Image)
.join(Parent, JOIN_LEFT_OUTER, on=(Image.parent == Parent.id))
.join(ParentImageStorage, JOIN_LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage))
.join(Parent, JOIN.LEFT_OUTER, on=(Image.parent == Parent.id))
.join(ParentImageStorage, JOIN.LEFT_OUTER, on=(ParentImageStorage.id == Parent.storage))
.where(RepositoryTag.hidden == False)
.where(Image.security_indexed_engine < clair_version))
@@ -71,7 +71,7 @@ def filter_tags_have_repository_event(query, event):
lifetime_start_ts.
"""
query = filter_has_repository_event(query, event)
query = query.switch(Image).join(ImageStorage)
query = query.switch(RepositoryTag).join(Image).join(ImageStorage)
query = query.switch(RepositoryTag).order_by(RepositoryTag.lifetime_start_ts.desc())
return query
@@ -146,12 +146,13 @@ def get_matching_tags_for_images(image_pairs, filter_images=None, filter_tags=No
# Collect IDs of the tags found for each query.
tags = {}
for query in sharded_queries:
ImageAlias = Image.alias()
tag_query = (_tag_alive(RepositoryTag
.select(*(selections or []))
.distinct()
.join(Image)
.join(ImageAlias)
.where(RepositoryTag.hidden == False)
.where(Image.id << query)
.where(ImageAlias.id << query)
.switch(RepositoryTag)))
if filter_tags is not None:
@@ -210,7 +211,7 @@ def list_active_repo_tags(repo):
.join(Image)
.where(RepositoryTag.repository == repo, RepositoryTag.hidden == False)
.switch(RepositoryTag)
.join(TagManifest, JOIN_LEFT_OUTER))
.join(TagManifest, JOIN.LEFT_OUTER))
return query


@@ -148,12 +148,11 @@ def add_or_invite_to_team(inviter, team, user_obj=None, email=None, requires_inv
def get_matching_user_teams(team_prefix, user_obj, limit=10):
team_prefix_search = prefix_search(Team.name, team_prefix)
query = (Team
.select()
.select(Team.id.distinct(), Team)
.join(User)
.switch(Team)
.join(TeamMember)
.where(TeamMember.user == user_obj, team_prefix_search)
.distinct(Team.id)
.limit(limit))
return query
@@ -179,12 +178,11 @@ def get_matching_admined_teams(team_prefix, user_obj, limit=10):
.where(TeamRole.name == 'admin'))
query = (Team
.select()
.select(Team.id.distinct(), Team)
.join(User)
.switch(Team)
.join(TeamMember)
.where(team_prefix_search, Team.organization << (admined_orgs))
.distinct(Team.id)
.limit(limit))
return query
@@ -260,8 +258,9 @@ def get_user_teams_within_org(username, organization):
def list_organization_members_by_teams(organization):
query = (TeamMember
.select(Team, User)
.annotate(Team)
.annotate(User)
.join(Team)
.switch(TeamMember)
.join(User)
.where(Team.organization == organization))
return query


@@ -1,6 +1,6 @@
import pytest
from peewee import JOIN_LEFT_OUTER
from peewee import JOIN
from playhouse.test_utils import assert_query_count
from data.database import Repository, RepositoryPermission, TeamMember, Namespace
@@ -87,7 +87,7 @@ def test_filter_repositories(username, include_public, filter_to_namespace, repo
.distinct()
.join(Namespace, on=(Repository.namespace_user == Namespace.id))
.switch(Repository)
.join(RepositoryPermission, JOIN_LEFT_OUTER))
.join(RepositoryPermission, JOIN.LEFT_OUTER))
with assert_query_count(1):
found = list(filter_to_repos_for_user(query, user.id,


@@ -8,6 +8,7 @@ from data.model.repository import create_repository, purge_repository, is_empty
from data.model.repository import get_filtered_matching_repositories
from test.fixtures import *
def test_duplicate_repository_different_kinds(initialized_db):
# Create an image repo.
create_repository('devtable', 'somenewrepo', None, repo_kind='image')


@@ -1,6 +1,6 @@
import logging
from peewee import JOIN_LEFT_OUTER
from peewee import JOIN
from data.database import (AccessToken, AccessTokenKind, Repository, Namespace, Role,
RepositoryBuildTrigger, LogEntryKind)
@@ -38,7 +38,7 @@ def get_repository_delegate_tokens(namespace_name, repository_name):
.switch(AccessToken)
.join(Role)
.switch(AccessToken)
.join(RepositoryBuildTrigger, JOIN_LEFT_OUTER)
.join(RepositoryBuildTrigger, JOIN.LEFT_OUTER)
.where(Repository.name == repository_name, Namespace.username == namespace_name,
AccessToken.temporary == False, RepositoryBuildTrigger.uuid >> None))


@@ -4,7 +4,7 @@ import json
import uuid
from flask_login import UserMixin
from peewee import JOIN_LEFT_OUTER, IntegrityError, fn
from peewee import JOIN, IntegrityError, fn
from uuid import uuid4
from datetime import datetime, timedelta
from enum import Enum
@@ -397,15 +397,18 @@ def _list_entity_robots(entity_name, include_metadata=True):
""" Return the list of robots for the specified entity. This MUST return a query, not a
materialized list so that callers can use db_for_update.
"""
query = (User
.select(User, FederatedLogin)
.join(FederatedLogin)
.where(User.robot == True, User.username ** (entity_name + '+%')))
if include_metadata:
query = (query.switch(User)
.join(RobotAccountMetadata, JOIN_LEFT_OUTER)
.select(User, FederatedLogin, RobotAccountMetadata))
query = (User
.select(User, FederatedLogin, RobotAccountMetadata)
.join(FederatedLogin)
.switch(User)
.join(RobotAccountMetadata, JOIN.LEFT_OUTER)
.where(User.robot == True, User.username ** (entity_name + '+%')))
else:
query = (User
.select(User, FederatedLogin)
.join(FederatedLogin)
.where(User.robot == True, User.username ** (entity_name + '+%')))
return query
@@ -417,12 +420,12 @@ def list_entity_robot_permission_teams(entity_name, limit=None, include_permissi
RobotAccountMetadata.description, RobotAccountMetadata.unstructured_json]
if include_permissions:
query = (query
.join(RepositoryPermission, JOIN_LEFT_OUTER,
.join(RepositoryPermission, JOIN.LEFT_OUTER,
on=(RepositoryPermission.user == FederatedLogin.user))
.join(Repository, JOIN_LEFT_OUTER)
.join(Repository, JOIN.LEFT_OUTER)
.switch(User)
.join(TeamMember, JOIN_LEFT_OUTER)
.join(Team, JOIN_LEFT_OUTER))
.join(TeamMember, JOIN.LEFT_OUTER)
.join(Team, JOIN.LEFT_OUTER))
fields.append(Repository.name)
fields.append(Team.name)
@@ -684,7 +687,7 @@ def get_matching_user_namespaces(namespace_prefix, username, limit=10):
.select()
.distinct()
.join(Repository, on=(Repository.namespace_user == Namespace.id))
.join(RepositoryPermission, JOIN_LEFT_OUTER)
.join(RepositoryPermission, JOIN.LEFT_OUTER)
.where(namespace_search))
return _basequery.filter_to_repos_for_user(base_query, namespace_user_id).limit(limit)
@@ -710,8 +713,8 @@ def get_matching_users(username_prefix, robot_namespace=None, organization=None,
if organization:
query = (query
.select(User.id, User.username, User.email, User.robot, fn.Sum(Team.id))
.join(TeamMember, JOIN_LEFT_OUTER)
.join(Team, JOIN_LEFT_OUTER, on=((Team.id == TeamMember.team) &
.join(TeamMember, JOIN.LEFT_OUTER)
.join(Team, JOIN.LEFT_OUTER, on=((Team.id == TeamMember.team) &
(Team.organization == organization)))
.order_by(User.robot.desc()))
@@ -790,7 +793,7 @@ def verify_user(username_or_email, password):
def get_all_repo_users(namespace_name, repository_name):
return (RepositoryPermission
.select(User.username, User.email, User.robot, Role.name, RepositoryPermission)
.select(User, Role, RepositoryPermission)
.join(User)
.switch(RepositoryPermission)
.join(Role)


@@ -1,4 +1,4 @@
from peewee import Clause, SQL, fn, TextField, Field
from peewee import NodeList, SQL, fn, TextField, Field
def _escape_wildcard(search_query):
""" Escapes the wildcards found in the given search query so that they are treated as *characters*
@@ -16,7 +16,7 @@ def prefix_search(field, prefix_query):
""" Returns the wildcard match for searching for the given prefix query. """
# Escape the known wildcard characters.
prefix_query = _escape_wildcard(prefix_query)
return Field.__pow__(field, Clause(prefix_query + '%', SQL("ESCAPE '!'")))
return Field.__pow__(field, NodeList((prefix_query + '%', SQL("ESCAPE '!'"))))
def match_mysql(field, search_query):
@@ -29,14 +29,13 @@ def match_mysql(field, search_query):
# queries of the form `*` to raise a parsing error. If found, simply filter out.
search_query = search_query.replace('*', '')
return Clause(fn.MATCH(SQL("`%s`" % field.name)), fn.AGAINST(SQL('%s', search_query)),
parens=True)
return NodeList((fn.MATCH(SQL("`%s`" % field.name)), fn.AGAINST(SQL('%s', [search_query]))),
parens=True)
def match_like(field, search_query):
""" Generates a full-text match query using an ILIKE operation, which is needed for SQLite and
Postgres.
"""
escaped_query = _escape_wildcard(search_query)
clause = Clause('%' + escaped_query + '%', SQL("ESCAPE '!'"))
clause = NodeList(('%' + escaped_query + '%', SQL("ESCAPE '!'")))
return Field.__pow__(field, clause)
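
peewee 3 also dropped Clause, which the text-search helpers above now express with NodeList. A standalone sketch of the same prefix-search pattern (the User model and the exact escaping below are illustrative assumptions, not taken from this repo):

from peewee import SqliteDatabase, Model, CharField, NodeList, SQL

db = SqliteDatabase(':memory:')

class User(Model):
    username = CharField()
    class Meta:
        database = db

def prefix_search(field, prefix_query):
    # Escape LIKE wildcards so they match literally (the intent of _escape_wildcard
    # above; the exact escaping used there may differ).
    escaped = (prefix_query
               .replace('!', '!!')
               .replace('%', '!%')
               .replace('_', '!_'))
    # field ** rhs is peewee's case-insensitive LIKE; the NodeList appends the ESCAPE clause.
    return field ** NodeList((escaped + '%', SQL("ESCAPE '!'")))

db.create_tables([User])
User.create(username='devtable+builder')
User.create(username='public')

print([u.username for u in User.select().where(prefix_search(User.username, 'devtable+'))])
# ['devtable+builder']

On SQLite, peewee renders the ** operator as LIKE, so the query above matches usernames beginning with the escaped prefix.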