diff --git a/data/model/log.py b/data/model/log.py index e61f456f7..a8849f793 100644 --- a/data/model/log.py +++ b/data/model/log.py @@ -26,7 +26,7 @@ def _logs_query(selections, start_time, end_time, performer=None, repository=Non if performer: joined = joined.where(LogEntry.performer == performer) - if namespace: + if namespace and not repository: namespace_user = user.get_user_or_org(namespace) if namespace_user is None: raise DataModelException('Invalid namespace requested') diff --git a/data/model/modelutil.py b/data/model/modelutil.py index 1a6cdab49..ac8bcd804 100644 --- a/data/model/modelutil.py +++ b/data/model/modelutil.py @@ -25,9 +25,8 @@ def paginate(query, model, descending=False, page_token=None, limit=50, id_alias query = query.where(model.id <= start_id) else: query = query.where(model.id >= start_id) - else: - query = query.limit(limit + 1) + query = query.limit(limit + 1) return paginate_query(query, limit=limit, id_alias=id_alias) diff --git a/data/model/test/test_modelutil.py b/data/model/test/test_modelutil.py new file mode 100644 index 000000000..5da72be4a --- /dev/null +++ b/data/model/test/test_modelutil.py @@ -0,0 +1,50 @@ +import pytest + +from data.database import Role +from data.model.modelutil import paginate +from test.fixtures import * + +@pytest.mark.parametrize('page_size', [ + 10, + 20, + 50, + 100, + 200, + 500, + 1000, +]) +@pytest.mark.parametrize('descending', [ + False, + True, +]) +def test_paginate(page_size, descending, initialized_db): + # Add a bunch of rows into a test table (`Role`). + for i in range(0, 522): + Role.create(name='testrole%s' % i) + + query = Role.select().where(Role.name ** 'testrole%') + all_matching_roles = list(query) + assert len(all_matching_roles) == 522 + + # Paginate a query to lookup roles. + collected = [] + page_token = None + while True: + results, page_token = paginate(query, Role, limit=page_size, descending=descending, + page_token=page_token) + assert len(results) <= page_size + collected.extend(results) + + if page_token is None: + break + + assert len(results) == page_size + + for index, result in enumerate(results[1:]): + if descending: + assert result.id < results[index].id + else: + assert result.id > results[index].id + + assert len(collected) == len(all_matching_roles) + assert {c.id for c in collected} == {a.id for a in all_matching_roles} diff --git a/endpoints/api/logs.py b/endpoints/api/logs.py index 9966aaa6c..f95be5d12 100644 --- a/endpoints/api/logs.py +++ b/endpoints/api/logs.py @@ -50,10 +50,10 @@ def get_logs(start_time, end_time, performer_name=None, repository_name=None, na include_namespace = namespace_name is None and repository_name is None return { - 'start_time': format_date(start_time), - 'end_time': format_date(end_time), - 'logs': [log.to_dict(kinds, include_namespace) for log in log_entry_page.logs], - }, log_entry_page.next_page_token + 'start_time': format_date(start_time), + 'end_time': format_date(end_time), + 'logs': [log.to_dict(kinds, include_namespace) for log in log_entry_page.logs], + }, log_entry_page.next_page_token def get_aggregate_logs(start_time, end_time, performer_name=None, repository=None, namespace=None, @@ -80,7 +80,6 @@ class RepositoryLogs(RepositoryParamResource): @parse_args() @query_param('starttime', 'Earliest time from which to get logs (%m/%d/%Y %Z)', type=str) @query_param('endtime', 'Latest time to which to get logs (%m/%d/%Y %Z)', type=str) - @query_param('page', 'The page number for the logs', type=int, default=1) @page_support() def get(self, namespace, repository, page_token, parsed_args): """ List the logs for the specified repository. """ @@ -89,8 +88,8 @@ class RepositoryLogs(RepositoryParamResource): start_time = parsed_args['starttime'] end_time = parsed_args['endtime'] - return get_logs(start_time, end_time, repository_name=repository, page_token=page_token, namespace_name=namespace, - ignore=SERVICE_LEVEL_LOG_KINDS) + return get_logs(start_time, end_time, repository_name=repository, page_token=page_token, + namespace_name=namespace) @resource('/v1/user/logs') @@ -111,8 +110,9 @@ class UserLogs(ApiResource): end_time = parsed_args['endtime'] user = get_authenticated_user() - return get_logs(start_time, end_time, performer_name=performer_name, namespace_name=user.username, - page_token=page_token, ignore=SERVICE_LEVEL_LOG_KINDS) + return get_logs(start_time, end_time, performer_name=performer_name, + namespace_name=user.username, page_token=page_token, + ignore=SERVICE_LEVEL_LOG_KINDS) @resource('/v1/organization//logs') @@ -126,7 +126,6 @@ class OrgLogs(ApiResource): @query_param('starttime', 'Earliest time from which to get logs. (%m/%d/%Y %Z)', type=str) @query_param('endtime', 'Latest time to which to get logs. (%m/%d/%Y %Z)', type=str) @query_param('performer', 'Username for which to filter logs.', type=str) - @query_param('page', 'The page number for the logs', type=int, default=1) @page_support() @require_scope(scopes.ORG_ADMIN) def get(self, orgname, page_token, parsed_args): @@ -160,8 +159,7 @@ class RepositoryAggregateLogs(RepositoryParamResource): start_time = parsed_args['starttime'] end_time = parsed_args['endtime'] - return get_aggregate_logs(start_time, end_time, repository=repository, namespace=namespace, - ignore=SERVICE_LEVEL_LOG_KINDS) + return get_aggregate_logs(start_time, end_time, repository=repository, namespace=namespace) @resource('/v1/user/aggregatelogs') diff --git a/endpoints/api/logs_models_pre_oci.py b/endpoints/api/logs_models_pre_oci.py index 8bbdddacc..da4d431a8 100644 --- a/endpoints/api/logs_models_pre_oci.py +++ b/endpoints/api/logs_models_pre_oci.py @@ -37,8 +37,8 @@ class PreOCIModel(LogEntryDataInterface): before it was changed to support the OCI specification. """ - def get_logs_query(self, start_time, end_time, performer_name=None, repository_name=None, namespace_name=None, - ignore=None, page_token=None): + def get_logs_query(self, start_time, end_time, performer_name=None, repository_name=None, + namespace_name=None, ignore=None, page_token=None): repo = None if repository_name and namespace_name: repo = model.repository.get_repository(namespace_name, repository_name) @@ -65,8 +65,8 @@ class PreOCIModel(LogEntryDataInterface): return False return True - def get_aggregated_logs(self, start_time, end_time, performer_name=None, repository_name=None, namespace_name=None, - ignore=None): + def get_aggregated_logs(self, start_time, end_time, performer_name=None, repository_name=None, + namespace_name=None, ignore=None): repo = None if repository_name and namespace_name: repo = model.repository.get_repository(namespace_name, repository_name)