diff --git a/data/model/modelutil.py b/data/model/modelutil.py index af1e6d123..bbea5b5fe 100644 --- a/data/model/modelutil.py +++ b/data/model/modelutil.py @@ -1,29 +1,38 @@ -def paginate(query, model, descending=False, page_token=None, limit=50, id_field='id'): +from peewee import SQL + +def paginate(query, model, descending=False, page_token=None, limit=50, id_alias=None): """ Paginates the given query using an ID range, starting at the optional page_token. Returns a *list* of matching results along with an unencrypted page_token for the next page, if any. If descending is set to True, orders by the ID descending rather than ascending. """ - query = query.limit(limit + 1) + # Note: We use the id_alias for the order_by, but not the where below. The alias is necessary + # for certain queries that use unions in MySQL, as it gets confused on which ID to order by. + # The where clause, on the other hand, cannot use the alias because Postgres does not allow + # aliases in where clauses. + id_field = model.id + if id_alias is not None: + id_field = SQL(id_alias) if descending: - query = query.order_by(getattr(model, id_field).desc()) + query = query.order_by(id_field.desc()) else: - query = query.order_by(getattr(model, id_field)) + query = query.order_by(id_field) if page_token is not None: start_id = page_token.get('start_id') if start_id is not None: if descending: - query = query.where(getattr(model, id_field) <= start_id) + query = query.where(model.id <= start_id) else: - query = query.where(getattr(model, id_field) >= start_id) + query = query.where(model.id >= start_id) results = list(query) page_token = None if len(results) > limit: + start_id = results[limit].id page_token = { - 'start_id': getattr(results[limit], id_field) + 'start_id': start_id } return results[0:limit], page_token diff --git a/endpoints/api/repository.py b/endpoints/api/repository.py index aab57ea66..6950402fe 100644 --- a/endpoints/api/repository.py +++ b/endpoints/api/repository.py @@ -9,10 +9,11 @@ from datetime import timedelta, datetime from flask import request, abort from data import model +from data.database import Repository as RepositoryTable from endpoints.api import (truthy_bool, format_date, nickname, log_action, validate_json_request, require_repo_read, require_repo_write, require_repo_admin, RepositoryParamResource, resource, query_param, parse_args, ApiResource, - request_error, require_scope, path_param, page_support, parse_args, + request_error, require_scope, path_param, page_support, parse_args, query_param, truthy_bool) from endpoints.exception import Unauthorized, NotFound, InvalidRequest, ExceedsLicenseException from endpoints.api.billing import lookup_allowed_private_repos, get_namespace_plan @@ -165,9 +166,9 @@ class RepositoryList(ApiResource): # Note: We only limit repositories when there isn't a namespace or starred filter, as they # result in far smaller queries. if not parsed_args['namespace'] and not parsed_args['starred']: - repos, next_page_token = model.modelutil.paginate(repo_query, repo_query.c, + repos, next_page_token = model.modelutil.paginate(repo_query, RepositoryTable, page_token=page_token, limit=REPOS_PER_PAGE, - id_field='rid') + id_alias='rid') else: repos = list(repo_query) next_page_token = None diff --git a/test/test_queue.py b/test/test_queue.py index 4703ba8b9..d31268850 100644 --- a/test/test_queue.py +++ b/test/test_queue.py @@ -66,8 +66,8 @@ class TestQueue(QueueTestCase): self.assertEqual(self.reporter.running_count, None) self.assertEqual(self.reporter.total, None) - self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) - self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2) + self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1) + self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2, available_after=-1) self.assertEqual(self.reporter.currently_processing, False) self.assertEqual(self.reporter.running_count, 0) self.assertEqual(self.reporter.total, 1) @@ -97,8 +97,8 @@ class TestQueue(QueueTestCase): self.assertEqual(self.reporter.total, 1) def test_different_canonical_names(self): - self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) - self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2) + self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1) + self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2, available_after=-1) self.assertEqual(self.reporter.running_count, 0) self.assertEqual(self.reporter.total, 2) @@ -115,8 +115,8 @@ class TestQueue(QueueTestCase): self.assertEqual(self.reporter.total, 2) def test_canonical_name(self): - self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) - self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1) + self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1) + self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1, available_after=-1) one = self.queue.get(ordering_required=True) self.assertNotEqual(QUEUE_NAME + '/abc/def/', one) @@ -125,7 +125,7 @@ class TestQueue(QueueTestCase): self.assertNotEqual(QUEUE_NAME + '/abc/def/ghi/', two) def test_expiration(self): - self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) + self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1) self.assertEqual(self.reporter.running_count, 0) self.assertEqual(self.reporter.total, 1) @@ -148,8 +148,8 @@ class TestQueue(QueueTestCase): self.assertEqual(self.reporter.total, 1) def test_specialized_queue(self): - self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) - self.queue.put(['def', 'def'], self.TEST_MESSAGE_2) + self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1) + self.queue.put(['def', 'def'], self.TEST_MESSAGE_2, available_after=-1) my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory, ['def'])) @@ -166,7 +166,7 @@ class TestQueue(QueueTestCase): def test_random_queue_no_duplicates(self): for msg in self.TEST_MESSAGES: - self.queue.put(['abc', 'def'], msg) + self.queue.put(['abc', 'def'], msg, available_after=-1) seen = set() for _ in range(1, 101): diff --git a/test/testconfig.py b/test/testconfig.py index 8e37113d6..aee1b902f 100644 --- a/test/testconfig.py +++ b/test/testconfig.py @@ -73,3 +73,5 @@ class TestConfig(DefaultConfig): INSTANCE_SERVICE_KEY_KID_LOCATION = 'test/data/test.kid' INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem' + + PROMETHEUS_AGGREGATOR_URL = None