diff --git a/data/model/_basequery.py b/data/model/_basequery.py index 0cb0f3fee..34b404c86 100644 --- a/data/model/_basequery.py +++ b/data/model/_basequery.py @@ -34,10 +34,15 @@ def get_public_repo_visibility(): return Visibility.get(name='public') -def filter_to_repos_for_user(query, username=None, namespace=None, include_public=True): +def filter_to_repos_for_user(query, username=None, namespace=None, include_public=True, + start_id=None): if not include_public and not username: return Repository.select().where(Repository.id == '-1') + # Add the start ID if necessary. + if start_id is not None: + query = query.where(Repository.id >= start_id) + # Build a set of queries that, when unioned together, return the full set of visible repositories # for the filters specified. queries = [] diff --git a/data/model/modelutil.py b/data/model/modelutil.py index 6b116ccf0..1a6cdab49 100644 --- a/data/model/modelutil.py +++ b/data/model/modelutil.py @@ -19,16 +19,30 @@ def paginate(query, model, descending=False, page_token=None, limit=50, id_alias else: query = query.order_by(id_field) - if page_token is not None: - start_id = page_token.get('start_id') - if start_id is not None: - if descending: - query = query.where(model.id <= start_id) - else: - query = query.where(model.id >= start_id) + start_id = pagination_start(page_token) + if start_id is not None: + if descending: + query = query.where(model.id <= start_id) + else: + query = query.where(model.id >= start_id) else: query = query.limit(limit + 1) + return paginate_query(query, limit=limit, id_alias=id_alias) + + +def pagination_start(page_token=None): + """ Returns the start ID for pagination for the given page token. Will return None if None. """ + if page_token is not None: + return page_token.get('start_id') + + return None + + +def paginate_query(query, limit=50, id_alias=None): + """ Executes the given query and returns a page's worth of results, as well as the page token + for the next page (if any). + """ results = list(query) page_token = None if len(results) > limit: diff --git a/data/model/repository.py b/data/model/repository.py index 9b3459362..cd37aa0a6 100644 --- a/data/model/repository.py +++ b/data/model/repository.py @@ -2,7 +2,7 @@ import logging import random from datetime import timedelta, datetime -from peewee import JOIN_LEFT_OUTER, fn +from peewee import JOIN_LEFT_OUTER, fn, SQL from cachetools import ttl_cache from data.model import (DataModelException, tag, db_transaction, storage, permission, @@ -245,7 +245,8 @@ def get_when_last_modified(repository_ids): return last_modified_map -def get_visible_repositories(username, namespace=None, include_public=False): +def get_visible_repositories(username, namespace=None, include_public=False, start_id=None, + limit=None): """ Returns the repositories visible to the given user (if any). """ if not include_public and not username: @@ -263,7 +264,12 @@ def get_visible_repositories(username, namespace=None, include_public=False): # Note: We only need the permissions table if we will filter based on a user's permissions. query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN_LEFT_OUTER) - query = _basequery.filter_to_repos_for_user(query, username, namespace, include_public) + query = _basequery.filter_to_repos_for_user(query, username, namespace, include_public, + start_id=start_id) + + if limit is not None: + query = query.limit(limit).order_by(SQL('rid')) + return query diff --git a/endpoints/api/repository.py b/endpoints/api/repository.py index 00f3aa258..b7fb86486 100644 --- a/endpoints/api/repository.py +++ b/endpoints/api/repository.py @@ -149,7 +149,8 @@ class RepositoryList(ApiResource): user = get_authenticated_user() username = user.username if user else None - repo_query = None + next_page_token = None + repos = None # Lookup the requested repositories (either starred or non-starred.) if parsed_args['starred']: @@ -157,24 +158,29 @@ class RepositoryList(ApiResource): # No repositories should be returned, as there is no user. abort(400) - repo_query = model.repository.get_user_starred_repositories(user) + # Return the full list of repos starred by the current user. + repos = list(model.repository.get_user_starred_repositories(user)) + elif parsed_args['namespace']: + # Repositories filtered by namespace do not need pagination (their results are fairly small), + # so we just do the lookup directly. + repos = list(model.repository.get_visible_repositories(username=username, + include_public=parsed_args['public'], + namespace=parsed_args['namespace'])) else: + # Determine the starting offset for pagination. Note that we don't use the normal + # model.modelutil.paginate method here, as that does not operate over UNION queries, which + # get_visible_repositories will return if there is a logged-in user (for performance reasons). + # + # Also note the +1 on the limit, as paginate_query uses the extra result to determine whether + # there is a next page. + start_id = model.modelutil.pagination_start(page_token) repo_query = model.repository.get_visible_repositories(username=username, include_public=parsed_args['public'], - namespace=parsed_args['namespace']) + start_id=start_id, + limit=REPOS_PER_PAGE+1) - # Note: We only limit repositories when there isn't a namespace or starred filter, as they - # result in far smaller queries. - if not parsed_args['namespace'] and not parsed_args['starred']: - # TODO: Fix pagination to support union queries and then remove this hack. - repo_query = model.repository.get_visible_repositories(None, - include_public=parsed_args['public']) - repos, next_page_token = model.modelutil.paginate(repo_query, RepositoryTable, - page_token=page_token, limit=REPOS_PER_PAGE, - id_alias='rid') - else: - repos = list(repo_query) - next_page_token = None + repos, next_page_token = model.modelutil.paginate_query(repo_query, limit=REPOS_PER_PAGE, + id_alias='rid') # Collect the IDs of the repositories found for subequent lookup of popularity # and/or last modified. diff --git a/test/test_api_usage.py b/test/test_api_usage.py index 2955f917e..f7a8dcb16 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -1447,12 +1447,11 @@ class TestCreateRepo(ApiTestCase): class TestListRepos(ApiTestCase): def test_listrepos_asguest(self): # Queries: Base + the list query - # TODO: uncomment once fixed - #with assert_query_count(BASE_QUERY_COUNT + 1): - json = self.getJsonResponse(RepositoryList, params=dict(public=True)) - self.assertEquals(len(json['repositories']), 1) + with assert_query_count(BASE_QUERY_COUNT + 1): + json = self.getJsonResponse(RepositoryList, params=dict(public=True)) + self.assertEquals(len(json['repositories']), 1) - def test_listrepos_asguest_withpages(self): + def assertPublicRepos(self, has_extras=False): public_user = model.user.get_user('public') # Delete all existing repos under the namespace. @@ -1461,7 +1460,7 @@ class TestListRepos(ApiTestCase): # Add public repos until we have enough for a few pages. required = set() - for i in range(0, REPOS_PER_PAGE * 2): + for i in range(0, REPOS_PER_PAGE * 3): name = 'publicrepo%s' % i model.repository.create_repository('public', name, public_user, visibility='public') @@ -1473,8 +1472,10 @@ class TestListRepos(ApiTestCase): json = self.getJsonResponse(RepositoryList, params=dict(public=True, next_page=next_page)) for repo in json['repositories']: name = repo['name'] - self.assertTrue(name in required) - required.remove(name) + if name in required: + required.remove(name) + else: + self.assertTrue(has_extras, "Could not find name %s in repos created" % name) if 'next_page' in json: self.assertEquals(len(json['repositories']), REPOS_PER_PAGE) @@ -1483,13 +1484,12 @@ class TestListRepos(ApiTestCase): next_page = json['next_page'] - # Ensure we found all the repositories. - self.assertEquals(0, len(required)) + def test_listrepos_asguest_withpages(self): + self.assertPublicRepos() - def test_listrepos_asorgmember(self): + def test_listrepos_asorgmember_withpages(self): self.login(READ_ACCESS_USER) - json = self.getJsonResponse(RepositoryList, params=dict(public=True)) - self.assertGreater(len(json['repositories']), 0) + self.assertPublicRepos(has_extras=True) def test_listrepos_filter(self): self.login(READ_ACCESS_USER)