import dateutil.parser

from datetime import datetime, timezone
from peewee import SQL


def paginate(query, model, descending=False, page_token=None, limit=50, sort_field_alias=None,
             max_page=None, sort_field_name=None):
    """
    Paginates the given query using a field range, starting at the optional page_token.

    Returns a *list* of matching results along with an unencrypted page_token for the next
    page, if any. If descending is set to True, orders by the field descending rather than
    ascending.

    :param query: the peewee query to paginate.
    :param model: the model class on which the sort field is looked up.
    :param descending: order by the sort field descending instead of ascending.
    :param page_token: a token dict as returned by a previous call (keys 'start_index',
        'page_number', 'is_datetime'), or None for the first page.
    :param limit: maximum number of results returned for this page.
    :param sort_field_alias: if given, ORDER BY uses ``SQL(sort_field_alias)`` and the next
        page's start index is read from the result attribute of that name.
    :param max_page: if given, a token whose page_number exceeds it yields ``([], None)``.
    :param sort_field_name: name of the model field to sort on; defaults to 'id'.
    :return: a tuple of (list of results, next page token dict or None).

    NOTE(review): the range bound is inclusive (>= / <=) and keyed on a single field. If the
    sort field is not unique, rows sharing a value across a page boundary can repeat.
    """
    # Note: We use the sort_field_alias for the order_by, but not the where below. The alias is
    # necessary for certain queries that use unions in MySQL, as it gets confused on which field
    # to order by. The where clause, on the other hand, cannot use the alias because Postgres does
    # not allow aliases in where clauses.
    sort_field_name = sort_field_name or 'id'
    filter_field = getattr(model, sort_field_name)
    order_field = filter_field
    if sort_field_alias is not None:
        # paginate_query reads the next start index back via the alias attribute.
        sort_field_name = sort_field_alias
        order_field = SQL(sort_field_alias)

    query = query.order_by(order_field.desc() if descending else order_field)

    start_index = pagination_start(page_token)
    if start_index is not None:
        # Inclusive bound: the token's start_index is the first row of this page.
        if descending:
            query = query.where(filter_field <= start_index)
        else:
            query = query.where(filter_field >= start_index)

    # Fetch one extra row so paginate_query can tell whether another page exists.
    query = query.limit(limit + 1)

    page_number = (page_token.get('page_number') or None) if page_token else None
    if page_number is not None and max_page is not None and page_number > max_page:
        return [], None

    return paginate_query(query, limit=limit, sort_field_name=sort_field_name,
                          page_number=page_number)


def pagination_start(page_token=None):
    """
    Returns the start index stored in the given page token, or None if page_token is None.

    If the token's 'is_datetime' flag is set, the stored ISO-8601 string is parsed back into a
    datetime with dateutil. A token without a 'start_index' yields None.
    """
    if page_token is None:
        return None

    start_index = page_token.get('start_index')
    # Only parse when there is something to parse; dateutil cannot parse None.
    if page_token.get('is_datetime') and start_index is not None:
        start_index = dateutil.parser.parse(start_index)
    return start_index


def _datetime_to_token_string(value):
    """Serializes a datetime as an ISO-8601 UTC string ending in 'Z'."""
    # Aware datetimes are converted to naive UTC first. Otherwise isoformat() would already
    # carry an offset and appending "Z" would give an unparseable "...+00:00Z".
    # Naive datetimes are assumed to already be UTC, as before.
    if value.tzinfo is not None and value.utcoffset() is not None:
        value = value.astimezone(timezone.utc).replace(tzinfo=None)
    return value.isoformat() + "Z"


def paginate_query(query, limit=50, sort_field_name=None, page_number=None):
    """
    Executes the given query and returns a page's worth of results, as well as the page token
    for the next page (if any).

    The query is expected to fetch up to ``limit + 1`` rows (``paginate`` applies that limit).
    If the extra row is present, it is not returned; its sort value becomes the next token's
    'start_index'.

    :param query: the query to execute (iterated fully).
    :param limit: number of results to return for this page.
    :param sort_field_name: attribute on each result holding the sort value; defaults to 'id'.
    :param page_number: the current page number from the incoming token, or None on the first page.
    :return: a tuple of (results[:limit], token dict or None). The token has keys
        'start_index', 'page_number' and 'is_datetime'; datetimes are stored as UTC
        ISO-8601 strings.
    """
    results = list(query)
    page_token = None
    if len(results) > limit:
        start_index = getattr(results[limit], sort_field_name or 'id')
        is_datetime = False
        if isinstance(start_index, datetime):
            start_index = _datetime_to_token_string(start_index)
            is_datetime = True

        page_token = {
            'start_index': start_index,
            'page_number': page_number + 1 if page_number else 1,
            'is_datetime': is_datetime,
        }

    return results[0:limit], page_token