From 29058201e5de32320cc588184a679419ca9f55e7 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Mon, 14 May 2018 11:40:31 -0400 Subject: [PATCH 1/2] Fix bug in modelutil pagination that caused us to load far more results than necessary Also adds tests for the modelutil pagination --- data/model/modelutil.py | 3 +- data/model/test/test_modelutil.py | 50 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 data/model/test/test_modelutil.py diff --git a/data/model/modelutil.py b/data/model/modelutil.py index 1a6cdab49..ac8bcd804 100644 --- a/data/model/modelutil.py +++ b/data/model/modelutil.py @@ -25,9 +25,8 @@ def paginate(query, model, descending=False, page_token=None, limit=50, id_alias query = query.where(model.id <= start_id) else: query = query.where(model.id >= start_id) - else: - query = query.limit(limit + 1) + query = query.limit(limit + 1) return paginate_query(query, limit=limit, id_alias=id_alias) diff --git a/data/model/test/test_modelutil.py b/data/model/test/test_modelutil.py new file mode 100644 index 000000000..5da72be4a --- /dev/null +++ b/data/model/test/test_modelutil.py @@ -0,0 +1,50 @@ +import pytest + +from data.database import Role +from data.model.modelutil import paginate +from test.fixtures import * + +@pytest.mark.parametrize('page_size', [ + 10, + 20, + 50, + 100, + 200, + 500, + 1000, +]) +@pytest.mark.parametrize('descending', [ + False, + True, +]) +def test_paginate(page_size, descending, initialized_db): + # Add a bunch of rows into a test table (`Role`). + for i in range(0, 522): + Role.create(name='testrole%s' % i) + + query = Role.select().where(Role.name ** 'testrole%') + all_matching_roles = list(query) + assert len(all_matching_roles) == 522 + + # Paginate a query to lookup roles. 
+ collected = [] + page_token = None + while True: + results, page_token = paginate(query, Role, limit=page_size, descending=descending, + page_token=page_token) + assert len(results) <= page_size + collected.extend(results) + + if page_token is None: + break + + assert len(results) == page_size + + for index, result in enumerate(results[1:]): + if descending: + assert result.id < results[index].id + else: + assert result.id > results[index].id + + assert len(collected) == len(all_matching_roles) + assert {c.id for c in collected} == {a.id for a in all_matching_roles} From e3248bde472ebb7df46ee5cd9cc0738944357a6b Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Mon, 14 May 2018 11:41:49 -0400 Subject: [PATCH 2/2] Small fixes to make loading of logs faster Removes filtering of log types where not necessary, removes filtering based on namespace when filtering based on repository (superfluous check that was causing issues in MySQL preventing the use of the correct index) and fix some other small issues around the API Fixes https://jira.coreos.com/browse/QUAY-931 --- data/model/log.py | 2 +- endpoints/api/logs.py | 22 ++++++++++------------ endpoints/api/logs_models_pre_oci.py | 8 ++++---- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/data/model/log.py b/data/model/log.py index e61f456f7..a8849f793 100644 --- a/data/model/log.py +++ b/data/model/log.py @@ -26,7 +26,7 @@ def _logs_query(selections, start_time, end_time, performer=None, repository=Non if performer: joined = joined.where(LogEntry.performer == performer) - if namespace: + if namespace and not repository: namespace_user = user.get_user_or_org(namespace) if namespace_user is None: raise DataModelException('Invalid namespace requested') diff --git a/endpoints/api/logs.py b/endpoints/api/logs.py index 9966aaa6c..f95be5d12 100644 --- a/endpoints/api/logs.py +++ b/endpoints/api/logs.py @@ -50,10 +50,10 @@ def get_logs(start_time, end_time, performer_name=None, repository_name=None, na 
include_namespace = namespace_name is None and repository_name is None return { - 'start_time': format_date(start_time), - 'end_time': format_date(end_time), - 'logs': [log.to_dict(kinds, include_namespace) for log in log_entry_page.logs], - }, log_entry_page.next_page_token + 'start_time': format_date(start_time), + 'end_time': format_date(end_time), + 'logs': [log.to_dict(kinds, include_namespace) for log in log_entry_page.logs], + }, log_entry_page.next_page_token def get_aggregate_logs(start_time, end_time, performer_name=None, repository=None, namespace=None, @@ -80,7 +80,6 @@ class RepositoryLogs(RepositoryParamResource): @parse_args() @query_param('starttime', 'Earliest time from which to get logs (%m/%d/%Y %Z)', type=str) @query_param('endtime', 'Latest time to which to get logs (%m/%d/%Y %Z)', type=str) - @query_param('page', 'The page number for the logs', type=int, default=1) @page_support() def get(self, namespace, repository, page_token, parsed_args): """ List the logs for the specified repository. 
""" @@ -89,8 +88,8 @@ class RepositoryLogs(RepositoryParamResource): start_time = parsed_args['starttime'] end_time = parsed_args['endtime'] - return get_logs(start_time, end_time, repository_name=repository, page_token=page_token, namespace_name=namespace, - ignore=SERVICE_LEVEL_LOG_KINDS) + return get_logs(start_time, end_time, repository_name=repository, page_token=page_token, + namespace_name=namespace) @resource('/v1/user/logs') @@ -111,8 +110,9 @@ class UserLogs(ApiResource): end_time = parsed_args['endtime'] user = get_authenticated_user() - return get_logs(start_time, end_time, performer_name=performer_name, namespace_name=user.username, - page_token=page_token, ignore=SERVICE_LEVEL_LOG_KINDS) + return get_logs(start_time, end_time, performer_name=performer_name, + namespace_name=user.username, page_token=page_token, + ignore=SERVICE_LEVEL_LOG_KINDS) @resource('/v1/organization/<orgname>/logs') @@ -126,7 +126,6 @@ class OrgLogs(ApiResource): @query_param('starttime', 'Earliest time from which to get logs. (%m/%d/%Y %Z)', type=str) @query_param('endtime', 'Latest time to which to get logs. 
(%m/%d/%Y %Z)', type=str) @query_param('performer', 'Username for which to filter logs.', type=str) - @query_param('page', 'The page number for the logs', type=int, default=1) @page_support() @require_scope(scopes.ORG_ADMIN) def get(self, orgname, page_token, parsed_args): @@ -160,8 +159,7 @@ class RepositoryAggregateLogs(RepositoryParamResource): start_time = parsed_args['starttime'] end_time = parsed_args['endtime'] - return get_aggregate_logs(start_time, end_time, repository=repository, namespace=namespace, - ignore=SERVICE_LEVEL_LOG_KINDS) + return get_aggregate_logs(start_time, end_time, repository=repository, namespace=namespace) @resource('/v1/user/aggregatelogs') diff --git a/endpoints/api/logs_models_pre_oci.py b/endpoints/api/logs_models_pre_oci.py index 8bbdddacc..da4d431a8 100644 --- a/endpoints/api/logs_models_pre_oci.py +++ b/endpoints/api/logs_models_pre_oci.py @@ -37,8 +37,8 @@ class PreOCIModel(LogEntryDataInterface): before it was changed to support the OCI specification. """ - def get_logs_query(self, start_time, end_time, performer_name=None, repository_name=None, namespace_name=None, - ignore=None, page_token=None): + def get_logs_query(self, start_time, end_time, performer_name=None, repository_name=None, + namespace_name=None, ignore=None, page_token=None): repo = None if repository_name and namespace_name: repo = model.repository.get_repository(namespace_name, repository_name) @@ -65,8 +65,8 @@ class PreOCIModel(LogEntryDataInterface): return False return True - def get_aggregated_logs(self, start_time, end_time, performer_name=None, repository_name=None, namespace_name=None, - ignore=None): + def get_aggregated_logs(self, start_time, end_time, performer_name=None, repository_name=None, + namespace_name=None, ignore=None): repo = None if repository_name and namespace_name: repo = model.repository.get_repository(namespace_name, repository_name)