Change repo filtering for users to use a user ID reference, rather than the username

While this means we need an additional query for initial lookup, it makes the *filtering* query (which is the heavy part) require far fewer joins, thus making it more efficient.

Also adds a new unit test to verify that the filter returns the correct set of repositories.
This commit is contained in:
Joseph Schorr 2018-06-19 10:51:30 -04:00
parent f2b9aa4527
commit 7604e9842b
7 changed files with 158 additions and 34 deletions

View file

@ -69,9 +69,9 @@ def _lookup_team_roles():
return {role.name:role for role in TeamRole.select()} return {role.name:role for role in TeamRole.select()}
def filter_to_repos_for_user(query, username=None, namespace=None, repo_kind='image', def filter_to_repos_for_user(query, user_id=None, namespace=None, repo_kind='image',
include_public=True, start_id=None): include_public=True, start_id=None):
if not include_public and not username: if not include_public and not user_id:
return Repository.select().where(Repository.id == '-1') return Repository.select().where(Repository.id == '-1')
# Filter on the type of repository. # Filter on the type of repository.
@ -85,32 +85,28 @@ def filter_to_repos_for_user(query, username=None, namespace=None, repo_kind='im
if start_id is not None: if start_id is not None:
query = query.where(Repository.id >= start_id) query = query.where(Repository.id >= start_id)
# Add a namespace filter if necessary.
if namespace:
query = query.where(Namespace.username == namespace)
# Build a set of queries that, when unioned together, return the full set of visible repositories # Build a set of queries that, when unioned together, return the full set of visible repositories
# for the filters specified. # for the filters specified.
queries = [] queries = []
where_clause = (True)
if namespace:
where_clause = (Namespace.username == namespace)
if include_public: if include_public:
queries.append(query queries.append(query
.clone() .clone()
.where(Repository.visibility == get_public_repo_visibility(), where_clause)) .where(Repository.visibility == get_public_repo_visibility()))
if username: if user_id is not None:
UserThroughTeam = User.alias()
Org = User.alias()
AdminTeam = Team.alias() AdminTeam = Team.alias()
AdminTeamMember = TeamMember.alias() AdminTeamMember = TeamMember.alias()
AdminUser = User.alias()
# Add repositories in which the user has permission. # Add repositories in which the user has permission.
queries.append(query queries.append(query
.clone() .clone()
.switch(RepositoryPermission) .switch(RepositoryPermission)
.join(User) .where(RepositoryPermission.user == user_id))
.where(User.username == username, where_clause))
# Add repositories in which the user is a member of a team that has permission. # Add repositories in which the user is a member of a team that has permission.
queries.append(query queries.append(query
@ -118,20 +114,16 @@ def filter_to_repos_for_user(query, username=None, namespace=None, repo_kind='im
.switch(RepositoryPermission) .switch(RepositoryPermission)
.join(Team) .join(Team)
.join(TeamMember) .join(TeamMember)
.join(UserThroughTeam, on=(UserThroughTeam.id == TeamMember.user)) .where(TeamMember.user == user_id))
.where(UserThroughTeam.username == username, where_clause))
# Add repositories under namespaces in which the user is the org admin. # Add repositories under namespaces in which the user is the org admin.
queries.append(query queries.append(query
.clone() .clone()
.switch(Repository) .switch(Repository)
.join(Org, on=(Repository.namespace_user == Org.id)) .join(AdminTeam, on=(Repository.namespace_user == AdminTeam.organization))
.join(AdminTeam, on=(Org.id == AdminTeam.organization))
.where(AdminTeam.role == _lookup_team_role('admin'))
.switch(AdminTeam)
.join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team)) .join(AdminTeamMember, on=(AdminTeam.id == AdminTeamMember.team))
.join(AdminUser, on=(AdminTeamMember.user == AdminUser.id)) .where(AdminTeam.role == _lookup_team_role('admin'))
.where(AdminUser.username == username, where_clause)) .where(AdminTeamMember.user == user_id))
return reduce(lambda l, r: l | r, queries) return reduce(lambda l, r: l | r, queries)

View file

@ -12,12 +12,18 @@ from data.model import (DataModelException, db_transaction, _basequery, storage,
InvalidImageException) InvalidImageException)
from data.database import (Image, Repository, ImageStoragePlacement, Namespace, ImageStorage, from data.database import (Image, Repository, ImageStoragePlacement, Namespace, ImageStorage,
ImageStorageLocation, RepositoryPermission, DerivedStorageForImage, ImageStorageLocation, RepositoryPermission, DerivedStorageForImage,
ImageStorageTransformation) ImageStorageTransformation, User)
from util.canonicaljson import canonicalize from util.canonicaljson import canonicalize
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _namespace_id_for_username(username):
try:
return User.get(username=username).id
except User.DoesNotExist:
return None
def get_image_with_storage(docker_image_id, storage_uuid): def get_image_with_storage(docker_image_id, storage_uuid):
""" Returns the image with the given docker image ID and storage uuid or None if none. """ Returns the image with the given docker image ID and storage uuid or None if none.
@ -273,7 +279,8 @@ def find_create_or_link_image(docker_image_id, repo_obj, username, translations,
.where(ImageStorage.uploading == False, .where(ImageStorage.uploading == False,
Image.docker_image_id == docker_image_id)) Image.docker_image_id == docker_image_id))
existing_image_query = _basequery.filter_to_repos_for_user(existing_image_query, username) existing_image_query = _basequery.filter_to_repos_for_user(existing_image_query,
_namespace_id_for_username(username))
# If there is an existing image, we try to translate its ancestry and copy its storage. # If there is an existing image, we try to translate its ancestry and copy its storage.
new_image = None new_image = None

View file

@ -403,11 +403,17 @@ def get_visible_repositories(username, namespace=None, kind_filter='image', incl
Namespace.username, Repository.visibility, Repository.kind) Namespace.username, Repository.visibility, Repository.kind)
.switch(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id))) .switch(Repository).join(Namespace, on=(Repository.namespace_user == Namespace.id)))
user_id = None
if username: if username:
# Note: We only need the permissions table if we will filter based on a user's permissions. # Note: We only need the permissions table if we will filter based on a user's permissions.
query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN_LEFT_OUTER) query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN_LEFT_OUTER)
found_namespace = _get_namespace_user(username)
if not found_namespace:
return Repository.select(Repository.id.alias('rid')).where(Repository.id == -1)
query = _basequery.filter_to_repos_for_user(query, username, namespace, kind_filter, user_id = found_namespace.id
query = _basequery.filter_to_repos_for_user(query, user_id, namespace, kind_filter,
include_public, start_id=start_id) include_public, start_id=start_id)
if limit is not None: if limit is not None:
@ -434,6 +440,13 @@ def get_app_search(lookup, search_fields=None, username=None, limit=50):
offset=0, limit=limit) offset=0, limit=limit)
def _get_namespace_user(username):
    """ Looks up the User row with the given username, returning None when there is no match. """
    try:
        namespace_user = User.get(username=username)
    except User.DoesNotExist:
        namespace_user = None

    return namespace_user
def get_filtered_matching_repositories(lookup_value, filter_username=None, repo_kind='image', def get_filtered_matching_repositories(lookup_value, filter_username=None, repo_kind='image',
offset=0, limit=25, search_fields=None): offset=0, limit=25, search_fields=None):
""" Returns an iterator of all repositories matching the given lookup value, with optional """ Returns an iterator of all repositories matching the given lookup value, with optional
@ -451,8 +464,12 @@ def get_filtered_matching_repositories(lookup_value, filter_username=None, repo_
# Add a filter to the iterator, if necessary. # Add a filter to the iterator, if necessary.
if filter_username is not None: if filter_username is not None:
iterator = _filter_repositories_visible_to_username(unfiltered_query, filter_username, limit, filter_user = _get_namespace_user(filter_username)
repo_kind) if filter_user is None:
return []
iterator = _filter_repositories_visible_to_user(unfiltered_query, filter_user.id, limit,
repo_kind)
if offset > 0: if offset > 0:
take(offset, iterator) take(offset, iterator)
@ -462,7 +479,7 @@ def get_filtered_matching_repositories(lookup_value, filter_username=None, repo_
return list(unfiltered_query.offset(offset).limit(limit)) return list(unfiltered_query.offset(offset).limit(limit))
def _filter_repositories_visible_to_username(unfiltered_query, filter_username, limit, repo_kind): def _filter_repositories_visible_to_user(unfiltered_query, filter_user_id, limit, repo_kind):
encountered = set() encountered = set()
chunk_count = limit * 2 chunk_count = limit * 2
unfiltered_page = 0 unfiltered_page = 0
@ -484,11 +501,13 @@ def _filter_repositories_visible_to_username(unfiltered_query, filter_username,
encountered.update(new_unfiltered_ids) encountered.update(new_unfiltered_ids)
# Filter the repositories found to only those visible to the current user. # Filter the repositories found to only those visible to the current user.
query = (Repository.select(Repository, Namespace).distinct() query = (Repository
.select(Repository, Namespace)
.distinct()
.join(Namespace, on=(Namespace.id == Repository.namespace_user)).switch(Repository) .join(Namespace, on=(Namespace.id == Repository.namespace_user)).switch(Repository)
.join(RepositoryPermission).where(Repository.id << list(new_unfiltered_ids))) .join(RepositoryPermission).where(Repository.id << list(new_unfiltered_ids)))
filtered = _basequery.filter_to_repos_for_user(query, filter_username, repo_kind=repo_kind) filtered = _basequery.filter_to_repos_for_user(query, filter_user_id, repo_kind=repo_kind)
# Sort the filtered repositories by their initial order. # Sort the filtered repositories by their initial order.
all_filtered_repos = list(filtered) all_filtered_repos = list(filtered)

View file

@ -0,0 +1,104 @@
import pytest
from peewee import JOIN_LEFT_OUTER
from playhouse.test_utils import assert_query_count
from data.database import Repository, RepositoryPermission, TeamMember, Namespace
from data.model._basequery import filter_to_repos_for_user
from data.model.organization import get_admin_users
from data.model.user import get_namespace_user
from util.names import parse_robot_username
from test.fixtures import *
def _is_team_member(team, user):
    """ Returns whether a TeamMember row links the given user to the given team. """
    target_id = user.id
    memberships = TeamMember.select().where(TeamMember.team == team)
    return any(membership.user_id == target_id for membership in memberships)
def _get_visible_repositories_for_user(user, repo_kind='image', include_public=False,
                                       namespace=None):
    """ Yields every repository visible to the given user, computed one repository at a time.

        A repository is yielded when it passes the optional `repo_kind` and `namespace` filters
        and at least one of the following holds:
          - it is public (only considered when `include_public` is set),
          - the user has a direct permission on it,
          - the user is a member of a team that has a permission on it,
          - the user is among `get_admin_users` for the repository's namespace.
    """
    for repo in Repository.select():
        if repo_kind is not None and repo.kind.name != repo_kind:
            continue

        if namespace is not None and repo.namespace_user.username != namespace:
            continue

        if include_public and repo.visibility.name == 'public':
            yield repo
            continue

        # Direct repo permission. `get` already returns the matching row or raises DoesNotExist,
        # so no second `.get()` on the result is needed (it would only issue an extra query).
        try:
            RepositoryPermission.get(repository=repo, user=user)
        except RepositoryPermission.DoesNotExist:
            pass
        else:
            yield repo
            continue

        # Team permission.
        repo_perms = RepositoryPermission.select().where(RepositoryPermission.repository == repo)
        if any(perm.team and _is_team_member(perm.team, user) for perm in repo_perms):
            yield repo
            continue

        # Org namespace admin permission.
        if user in get_admin_users(repo.namespace_user):
            yield repo
@pytest.mark.parametrize('username', ['devtable', 'devtable+dtrobot', 'public', 'reader'])
@pytest.mark.parametrize('include_public', [True, False])
@pytest.mark.parametrize('filter_to_namespace', [True, False])
@pytest.mark.parametrize('repo_kind', [None, 'image', 'application'])
def test_filter_repositories(username, include_public, filter_to_namespace, repo_kind,
                             initialized_db):
    """ Checks that filter_to_repos_for_user, run as a single query, returns exactly the
        repositories that the row-by-row helper considers visible to the user.
    """
    # When filtering by namespace, a robot account ("namespace+robot") filters on the first
    # element returned by parse_robot_username; any other username is used as-is.
    if not filter_to_namespace:
        namespace = None
    elif '+' in username:
        namespace, _ = parse_robot_username(username)
    else:
        namespace = username

    user = get_namespace_user(username)
    base_query = (Repository
                  .select()
                  .distinct()
                  .join(Namespace, on=(Repository.namespace_user == Namespace.id))
                  .switch(Repository)
                  .join(RepositoryPermission, JOIN_LEFT_OUTER))

    filter_kwargs = dict(namespace=namespace, include_public=include_public, repo_kind=repo_kind)

    # The filtered lookup must cost exactly one query.
    with assert_query_count(1):
        found = list(filter_to_repos_for_user(base_query, user.id, **filter_kwargs))

    expected = list(_get_visible_repositories_for_user(user, **filter_kwargs))

    assert len(found) == len(expected)
    assert {r.id for r in found} == {r.id for r in expected}

View file

@ -668,6 +668,9 @@ def invalidate_all_sessions(user):
user.save() user.save()
def get_matching_user_namespaces(namespace_prefix, username, limit=10): def get_matching_user_namespaces(namespace_prefix, username, limit=10):
namespace_user = get_namespace_user(username)
namespace_user_id = namespace_user.id if namespace_user is not None else None
namespace_search = prefix_search(Namespace.username, namespace_prefix) namespace_search = prefix_search(Namespace.username, namespace_prefix)
base_query = (Namespace base_query = (Namespace
.select() .select()
@ -676,7 +679,7 @@ def get_matching_user_namespaces(namespace_prefix, username, limit=10):
.join(RepositoryPermission, JOIN_LEFT_OUTER) .join(RepositoryPermission, JOIN_LEFT_OUTER)
.where(namespace_search)) .where(namespace_search))
return _basequery.filter_to_repos_for_user(base_query, username).limit(limit) return _basequery.filter_to_repos_for_user(base_query, namespace_user_id).limit(limit)
def get_matching_users(username_prefix, robot_namespace=None, organization=None, limit=20, def get_matching_users(username_prefix, robot_namespace=None, organization=None, limit=20,
exact_matches_only=False): exact_matches_only=False):

View file

@ -2,7 +2,6 @@ import pytest
from playhouse.test_utils import assert_query_count from playhouse.test_utils import assert_query_count
from data.model import _basequery
from endpoints.api.search import ConductRepositorySearch, ConductSearch from endpoints.api.search import ConductRepositorySearch, ConductSearch
from endpoints.api.test.shared import conduct_api_call from endpoints.api.test.shared import conduct_api_call
from endpoints.test.shared import client_with_identity from endpoints.test.shared import client_with_identity
@ -17,7 +16,7 @@ from test.fixtures import *
def test_repository_search(query, client): def test_repository_search(query, client):
with client_with_identity('devtable', client) as cl: with client_with_identity('devtable', client) as cl:
params = {'query': query} params = {'query': query}
with assert_query_count(6): with assert_query_count(7):
result = conduct_api_call(cl, ConductRepositorySearch, 'GET', params, None, 200).json result = conduct_api_call(cl, ConductRepositorySearch, 'GET', params, None, 200).json
assert result['start_index'] == 0 assert result['start_index'] == 0
assert result['page'] == 1 assert result['page'] == 1
@ -32,6 +31,6 @@ def test_repository_search(query, client):
def test_search_query_count(query, client): def test_search_query_count(query, client):
with client_with_identity('devtable', client) as cl: with client_with_identity('devtable', client) as cl:
params = {'query': query} params = {'query': query}
with assert_query_count(8): with assert_query_count(10):
result = conduct_api_call(cl, ConductSearch, 'GET', params, None, 200).json result = conduct_api_call(cl, ConductSearch, 'GET', params, None, 200).json
assert len(result['results']) assert len(result['results'])

View file

@ -1858,7 +1858,7 @@ class TestListRepos(ApiTestCase):
self.login(ADMIN_ACCESS_USER) self.login(ADMIN_ACCESS_USER)
# Queries: Base + the list query + the popularity and last modified queries + full perms load # Queries: Base + the list query + the popularity and last modified queries + full perms load
with assert_query_count(BASE_LOGGEDIN_QUERY_COUNT + 4): with assert_query_count(BASE_LOGGEDIN_QUERY_COUNT + 5):
json = self.getJsonResponse(RepositoryList, params=dict(namespace=ORGANIZATION, public=False, json = self.getJsonResponse(RepositoryList, params=dict(namespace=ORGANIZATION, public=False,
last_modified=True, popularity=True)) last_modified=True, popularity=True))