Merge pull request #1622 from coreos-inc/rid-fix
Various small fixes around SQL and testing
This commit is contained in:
commit
9f6b47ad1f
4 changed files with 32 additions and 20 deletions
|
@ -1,29 +1,38 @@
|
||||||
def paginate(query, model, descending=False, page_token=None, limit=50, id_field='id'):
|
from peewee import SQL
|
||||||
|
|
||||||
|
def paginate(query, model, descending=False, page_token=None, limit=50, id_alias=None):
|
||||||
""" Paginates the given query using an ID range, starting at the optional page_token.
|
""" Paginates the given query using an ID range, starting at the optional page_token.
|
||||||
Returns a *list* of matching results along with an unencrypted page_token for the
|
Returns a *list* of matching results along with an unencrypted page_token for the
|
||||||
next page, if any. If descending is set to True, orders by the ID descending rather
|
next page, if any. If descending is set to True, orders by the ID descending rather
|
||||||
than ascending.
|
than ascending.
|
||||||
"""
|
"""
|
||||||
query = query.limit(limit + 1)
|
# Note: We use the id_alias for the order_by, but not the where below. The alias is necessary
|
||||||
|
# for certain queries that use unions in MySQL, as it gets confused on which ID to order by.
|
||||||
|
# The where clause, on the other hand, cannot use the alias because Postgres does not allow
|
||||||
|
# aliases in where clauses.
|
||||||
|
id_field = model.id
|
||||||
|
if id_alias is not None:
|
||||||
|
id_field = SQL(id_alias)
|
||||||
|
|
||||||
if descending:
|
if descending:
|
||||||
query = query.order_by(getattr(model, id_field).desc())
|
query = query.order_by(id_field.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(getattr(model, id_field))
|
query = query.order_by(id_field)
|
||||||
|
|
||||||
if page_token is not None:
|
if page_token is not None:
|
||||||
start_id = page_token.get('start_id')
|
start_id = page_token.get('start_id')
|
||||||
if start_id is not None:
|
if start_id is not None:
|
||||||
if descending:
|
if descending:
|
||||||
query = query.where(getattr(model, id_field) <= start_id)
|
query = query.where(model.id <= start_id)
|
||||||
else:
|
else:
|
||||||
query = query.where(getattr(model, id_field) >= start_id)
|
query = query.where(model.id >= start_id)
|
||||||
|
|
||||||
results = list(query)
|
results = list(query)
|
||||||
page_token = None
|
page_token = None
|
||||||
if len(results) > limit:
|
if len(results) > limit:
|
||||||
|
start_id = results[limit].id
|
||||||
page_token = {
|
page_token = {
|
||||||
'start_id': getattr(results[limit], id_field)
|
'start_id': start_id
|
||||||
}
|
}
|
||||||
|
|
||||||
return results[0:limit], page_token
|
return results[0:limit], page_token
|
||||||
|
|
|
@ -9,10 +9,11 @@ from datetime import timedelta, datetime
|
||||||
from flask import request, abort
|
from flask import request, abort
|
||||||
|
|
||||||
from data import model
|
from data import model
|
||||||
|
from data.database import Repository as RepositoryTable
|
||||||
from endpoints.api import (truthy_bool, format_date, nickname, log_action, validate_json_request,
|
from endpoints.api import (truthy_bool, format_date, nickname, log_action, validate_json_request,
|
||||||
require_repo_read, require_repo_write, require_repo_admin,
|
require_repo_read, require_repo_write, require_repo_admin,
|
||||||
RepositoryParamResource, resource, query_param, parse_args, ApiResource,
|
RepositoryParamResource, resource, query_param, parse_args, ApiResource,
|
||||||
request_error, require_scope, path_param, page_support, parse_args,
|
request_error, require_scope, path_param, page_support, parse_args,
|
||||||
query_param, truthy_bool)
|
query_param, truthy_bool)
|
||||||
from endpoints.exception import Unauthorized, NotFound, InvalidRequest, ExceedsLicenseException
|
from endpoints.exception import Unauthorized, NotFound, InvalidRequest, ExceedsLicenseException
|
||||||
from endpoints.api.billing import lookup_allowed_private_repos, get_namespace_plan
|
from endpoints.api.billing import lookup_allowed_private_repos, get_namespace_plan
|
||||||
|
@ -165,9 +166,9 @@ class RepositoryList(ApiResource):
|
||||||
# Note: We only limit repositories when there isn't a namespace or starred filter, as they
|
# Note: We only limit repositories when there isn't a namespace or starred filter, as they
|
||||||
# result in far smaller queries.
|
# result in far smaller queries.
|
||||||
if not parsed_args['namespace'] and not parsed_args['starred']:
|
if not parsed_args['namespace'] and not parsed_args['starred']:
|
||||||
repos, next_page_token = model.modelutil.paginate(repo_query, repo_query.c,
|
repos, next_page_token = model.modelutil.paginate(repo_query, RepositoryTable,
|
||||||
page_token=page_token, limit=REPOS_PER_PAGE,
|
page_token=page_token, limit=REPOS_PER_PAGE,
|
||||||
id_field='rid')
|
id_alias='rid')
|
||||||
else:
|
else:
|
||||||
repos = list(repo_query)
|
repos = list(repo_query)
|
||||||
next_page_token = None
|
next_page_token = None
|
||||||
|
|
|
@ -66,8 +66,8 @@ class TestQueue(QueueTestCase):
|
||||||
self.assertEqual(self.reporter.running_count, None)
|
self.assertEqual(self.reporter.running_count, None)
|
||||||
self.assertEqual(self.reporter.total, None)
|
self.assertEqual(self.reporter.total, None)
|
||||||
|
|
||||||
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
|
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
|
||||||
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2)
|
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2, available_after=-1)
|
||||||
self.assertEqual(self.reporter.currently_processing, False)
|
self.assertEqual(self.reporter.currently_processing, False)
|
||||||
self.assertEqual(self.reporter.running_count, 0)
|
self.assertEqual(self.reporter.running_count, 0)
|
||||||
self.assertEqual(self.reporter.total, 1)
|
self.assertEqual(self.reporter.total, 1)
|
||||||
|
@ -97,8 +97,8 @@ class TestQueue(QueueTestCase):
|
||||||
self.assertEqual(self.reporter.total, 1)
|
self.assertEqual(self.reporter.total, 1)
|
||||||
|
|
||||||
def test_different_canonical_names(self):
|
def test_different_canonical_names(self):
|
||||||
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
|
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
|
||||||
self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2)
|
self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2, available_after=-1)
|
||||||
self.assertEqual(self.reporter.running_count, 0)
|
self.assertEqual(self.reporter.running_count, 0)
|
||||||
self.assertEqual(self.reporter.total, 2)
|
self.assertEqual(self.reporter.total, 2)
|
||||||
|
|
||||||
|
@ -115,8 +115,8 @@ class TestQueue(QueueTestCase):
|
||||||
self.assertEqual(self.reporter.total, 2)
|
self.assertEqual(self.reporter.total, 2)
|
||||||
|
|
||||||
def test_canonical_name(self):
|
def test_canonical_name(self):
|
||||||
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
|
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
|
||||||
self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1)
|
self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1, available_after=-1)
|
||||||
|
|
||||||
one = self.queue.get(ordering_required=True)
|
one = self.queue.get(ordering_required=True)
|
||||||
self.assertNotEqual(QUEUE_NAME + '/abc/def/', one)
|
self.assertNotEqual(QUEUE_NAME + '/abc/def/', one)
|
||||||
|
@ -125,7 +125,7 @@ class TestQueue(QueueTestCase):
|
||||||
self.assertNotEqual(QUEUE_NAME + '/abc/def/ghi/', two)
|
self.assertNotEqual(QUEUE_NAME + '/abc/def/ghi/', two)
|
||||||
|
|
||||||
def test_expiration(self):
|
def test_expiration(self):
|
||||||
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
|
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
|
||||||
self.assertEqual(self.reporter.running_count, 0)
|
self.assertEqual(self.reporter.running_count, 0)
|
||||||
self.assertEqual(self.reporter.total, 1)
|
self.assertEqual(self.reporter.total, 1)
|
||||||
|
|
||||||
|
@ -148,8 +148,8 @@ class TestQueue(QueueTestCase):
|
||||||
self.assertEqual(self.reporter.total, 1)
|
self.assertEqual(self.reporter.total, 1)
|
||||||
|
|
||||||
def test_specialized_queue(self):
|
def test_specialized_queue(self):
|
||||||
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
|
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
|
||||||
self.queue.put(['def', 'def'], self.TEST_MESSAGE_2)
|
self.queue.put(['def', 'def'], self.TEST_MESSAGE_2, available_after=-1)
|
||||||
|
|
||||||
my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory, ['def']))
|
my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory, ['def']))
|
||||||
|
|
||||||
|
@ -166,7 +166,7 @@ class TestQueue(QueueTestCase):
|
||||||
|
|
||||||
def test_random_queue_no_duplicates(self):
|
def test_random_queue_no_duplicates(self):
|
||||||
for msg in self.TEST_MESSAGES:
|
for msg in self.TEST_MESSAGES:
|
||||||
self.queue.put(['abc', 'def'], msg)
|
self.queue.put(['abc', 'def'], msg, available_after=-1)
|
||||||
seen = set()
|
seen = set()
|
||||||
|
|
||||||
for _ in range(1, 101):
|
for _ in range(1, 101):
|
||||||
|
|
|
@ -73,3 +73,5 @@ class TestConfig(DefaultConfig):
|
||||||
|
|
||||||
INSTANCE_SERVICE_KEY_KID_LOCATION = 'test/data/test.kid'
|
INSTANCE_SERVICE_KEY_KID_LOCATION = 'test/data/test.kid'
|
||||||
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
|
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
|
||||||
|
|
||||||
|
PROMETHEUS_AGGREGATOR_URL = None
|
||||||
|
|
Reference in a new issue