Merge pull request #1622 from coreos-inc/rid-fix

Various small fixes around SQL and testing
This commit is contained in:
josephschorr 2016-07-15 13:42:02 -04:00 committed by GitHub
commit 9f6b47ad1f
4 changed files with 32 additions and 20 deletions

View file

@ -1,29 +1,38 @@
def paginate(query, model, descending=False, page_token=None, limit=50, id_field='id'): from peewee import SQL
def paginate(query, model, descending=False, page_token=None, limit=50, id_alias=None):
""" Paginates the given query using an ID range, starting at the optional page_token. """ Paginates the given query using an ID range, starting at the optional page_token.
Returns a *list* of matching results along with an unencrypted page_token for the Returns a *list* of matching results along with an unencrypted page_token for the
next page, if any. If descending is set to True, orders by the ID descending rather next page, if any. If descending is set to True, orders by the ID descending rather
than ascending. than ascending.
""" """
query = query.limit(limit + 1) # Note: We use the id_alias for the order_by, but not the where below. The alias is necessary
# for certain queries that use unions in MySQL, as it gets confused on which ID to order by.
# The where clause, on the other hand, cannot use the alias because Postgres does not allow
# aliases in where clauses.
id_field = model.id
if id_alias is not None:
id_field = SQL(id_alias)
if descending: if descending:
query = query.order_by(getattr(model, id_field).desc()) query = query.order_by(id_field.desc())
else: else:
query = query.order_by(getattr(model, id_field)) query = query.order_by(id_field)
if page_token is not None: if page_token is not None:
start_id = page_token.get('start_id') start_id = page_token.get('start_id')
if start_id is not None: if start_id is not None:
if descending: if descending:
query = query.where(getattr(model, id_field) <= start_id) query = query.where(model.id <= start_id)
else: else:
query = query.where(getattr(model, id_field) >= start_id) query = query.where(model.id >= start_id)
results = list(query) results = list(query)
page_token = None page_token = None
if len(results) > limit: if len(results) > limit:
start_id = results[limit].id
page_token = { page_token = {
'start_id': getattr(results[limit], id_field) 'start_id': start_id
} }
return results[0:limit], page_token return results[0:limit], page_token

View file

@ -9,10 +9,11 @@ from datetime import timedelta, datetime
from flask import request, abort from flask import request, abort
from data import model from data import model
from data.database import Repository as RepositoryTable
from endpoints.api import (truthy_bool, format_date, nickname, log_action, validate_json_request, from endpoints.api import (truthy_bool, format_date, nickname, log_action, validate_json_request,
require_repo_read, require_repo_write, require_repo_admin, require_repo_read, require_repo_write, require_repo_admin,
RepositoryParamResource, resource, query_param, parse_args, ApiResource, RepositoryParamResource, resource, query_param, parse_args, ApiResource,
request_error, require_scope, path_param, page_support, parse_args, request_error, require_scope, path_param, page_support, parse_args,
query_param, truthy_bool) query_param, truthy_bool)
from endpoints.exception import Unauthorized, NotFound, InvalidRequest, ExceedsLicenseException from endpoints.exception import Unauthorized, NotFound, InvalidRequest, ExceedsLicenseException
from endpoints.api.billing import lookup_allowed_private_repos, get_namespace_plan from endpoints.api.billing import lookup_allowed_private_repos, get_namespace_plan
@ -165,9 +166,9 @@ class RepositoryList(ApiResource):
# Note: We only limit repositories when there isn't a namespace or starred filter, as they # Note: We only limit repositories when there isn't a namespace or starred filter, as they
# result in far smaller queries. # result in far smaller queries.
if not parsed_args['namespace'] and not parsed_args['starred']: if not parsed_args['namespace'] and not parsed_args['starred']:
repos, next_page_token = model.modelutil.paginate(repo_query, repo_query.c, repos, next_page_token = model.modelutil.paginate(repo_query, RepositoryTable,
page_token=page_token, limit=REPOS_PER_PAGE, page_token=page_token, limit=REPOS_PER_PAGE,
id_field='rid') id_alias='rid')
else: else:
repos = list(repo_query) repos = list(repo_query)
next_page_token = None next_page_token = None

View file

@ -66,8 +66,8 @@ class TestQueue(QueueTestCase):
self.assertEqual(self.reporter.running_count, None) self.assertEqual(self.reporter.running_count, None)
self.assertEqual(self.reporter.total, None) self.assertEqual(self.reporter.total, None)
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2) self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2, available_after=-1)
self.assertEqual(self.reporter.currently_processing, False) self.assertEqual(self.reporter.currently_processing, False)
self.assertEqual(self.reporter.running_count, 0) self.assertEqual(self.reporter.running_count, 0)
self.assertEqual(self.reporter.total, 1) self.assertEqual(self.reporter.total, 1)
@ -97,8 +97,8 @@ class TestQueue(QueueTestCase):
self.assertEqual(self.reporter.total, 1) self.assertEqual(self.reporter.total, 1)
def test_different_canonical_names(self): def test_different_canonical_names(self):
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2) self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2, available_after=-1)
self.assertEqual(self.reporter.running_count, 0) self.assertEqual(self.reporter.running_count, 0)
self.assertEqual(self.reporter.total, 2) self.assertEqual(self.reporter.total, 2)
@ -115,8 +115,8 @@ class TestQueue(QueueTestCase):
self.assertEqual(self.reporter.total, 2) self.assertEqual(self.reporter.total, 2)
def test_canonical_name(self): def test_canonical_name(self):
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1) self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1, available_after=-1)
one = self.queue.get(ordering_required=True) one = self.queue.get(ordering_required=True)
self.assertNotEqual(QUEUE_NAME + '/abc/def/', one) self.assertNotEqual(QUEUE_NAME + '/abc/def/', one)
@ -125,7 +125,7 @@ class TestQueue(QueueTestCase):
self.assertNotEqual(QUEUE_NAME + '/abc/def/ghi/', two) self.assertNotEqual(QUEUE_NAME + '/abc/def/ghi/', two)
def test_expiration(self): def test_expiration(self):
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
self.assertEqual(self.reporter.running_count, 0) self.assertEqual(self.reporter.running_count, 0)
self.assertEqual(self.reporter.total, 1) self.assertEqual(self.reporter.total, 1)
@ -148,8 +148,8 @@ class TestQueue(QueueTestCase):
self.assertEqual(self.reporter.total, 1) self.assertEqual(self.reporter.total, 1)
def test_specialized_queue(self): def test_specialized_queue(self):
self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1) self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
self.queue.put(['def', 'def'], self.TEST_MESSAGE_2) self.queue.put(['def', 'def'], self.TEST_MESSAGE_2, available_after=-1)
my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory, ['def'])) my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory, ['def']))
@ -166,7 +166,7 @@ class TestQueue(QueueTestCase):
def test_random_queue_no_duplicates(self): def test_random_queue_no_duplicates(self):
for msg in self.TEST_MESSAGES: for msg in self.TEST_MESSAGES:
self.queue.put(['abc', 'def'], msg) self.queue.put(['abc', 'def'], msg, available_after=-1)
seen = set() seen = set()
for _ in range(1, 101): for _ in range(1, 101):

View file

@ -73,3 +73,5 @@ class TestConfig(DefaultConfig):
INSTANCE_SERVICE_KEY_KID_LOCATION = 'test/data/test.kid' INSTANCE_SERVICE_KEY_KID_LOCATION = 'test/data/test.kid'
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem' INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
PROMETHEUS_AGGREGATOR_URL = None