From 76de8eab349722d53edf7a1fd42c1bf7db3e795f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 17:03:07 +0100 Subject: [PATCH] Refactor pagination --- wagtail/contrib/wagtailapi/endpoints.py | 87 +++++++----------------- wagtail/contrib/wagtailapi/pagination.py | 45 ++++++++++++ 2 files changed, 68 insertions(+), 64 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/pagination.py diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 1117e3a25..105cf2c87 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -27,6 +27,7 @@ from .filters import ( ChildOfFilter, DescendantOfFilter ) from .renderers import WagtailJSONRenderer +from .pagination import WagtailPagination from .utils import BadRequestError, URLPath, ObjectDetailURL @@ -83,6 +84,7 @@ def get_api_data(obj, fields): class BaseAPIEndpoint(GenericViewSet): renderer_classes = [WagtailJSONRenderer] + pagination_class = WagtailPagination filter_classes = [] known_query_parameters = frozenset([ @@ -179,34 +181,6 @@ class BaseAPIEndpoint(GenericViewSet): if unknown_parameters: raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters))) - def do_pagination(self, request, queryset): - """ - This performs limit/offset based pagination on the result set - Eg: ?limit=10&offset=20 -- Returns 10 items starting at item 20 - """ - limit_max = getattr(settings, 'WAGTAILAPI_LIMIT_MAX', 20) - - try: - offset = int(request.GET.get('offset', 0)) - assert offset >= 0 - except (ValueError, AssertionError): - raise BadRequestError("offset must be a positive integer") - - try: - limit = int(request.GET.get('limit', min(20, limit_max))) - - if limit > limit_max: - raise BadRequestError("limit cannot be higher than %d" % limit_max) - - assert limit >= 0 - except (ValueError, AssertionError): - raise BadRequestError("limit must be a positive integer") - - start = offset - stop = offset + limit - - return queryset[start:stop] - @classmethod def get_urlpatterns(cls): """ @@ -234,6 +208,7 @@ class BaseAPIEndpoint(GenericViewSet): class PagesAPIEndpoint(BaseAPIEndpoint): + name = 'pages' known_query_parameters = BaseAPIEndpoint.known_query_parameters.union([ 'type', 'child_of', @@ -310,8 +285,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): queryset = self.filter_queryset(queryset) # Pagination - total_count = queryset.count() - queryset = self.do_pagination(request, queryset) + queryset = self.paginate_queryset(queryset) # Get list of fields to show in results if 'fields' in request.GET: @@ -319,16 +293,11 @@ class PagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - data = OrderedDict([ - ('meta', OrderedDict([ - ('total_count', total_count), - ])), - ('pages', [ - self.serialize_object(request, page, fields=fields) - for page in queryset - ]), - ]) - return Response(data) + data = [ + self.serialize_object(request, page, fields=fields) + for page in queryset + ] + return self.get_paginated_response(data) def detail_view(self, request, pk): page = get_object_or_404(self.get_queryset(request), pk=pk).specific @@ -340,6 +309,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): + name = 'images' model = get_image_model() filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags', 'width', 'height'] @@ -357,8 +327,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): queryset = self.filter_queryset(queryset) # Pagination - total_count = queryset.count() - queryset = self.do_pagination(request, queryset) + queryset = self.paginate_queryset(queryset) # Get list of fields to show in results if 'fields' in request.GET: @@ -366,16 +335,11 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - data = OrderedDict([ - ('meta', OrderedDict([ - ('total_count', total_count), - ])), - ('images', [ - self.serialize_object(request, image, fields=fields) - for image in queryset - ]), - ]) - return Response(data) + data = [ + self.serialize_object(request, image, fields=fields) + for image in queryset + ] + return self.get_paginated_response(data) def detail_view(self, request, pk): image = get_object_or_404(self.get_queryset(request), pk=pk) @@ -387,6 +351,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): class DocumentsAPIEndpoint(BaseAPIEndpoint): + name = 'documents' filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags'] @@ -409,8 +374,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): queryset = self.filter_queryset(queryset) # Pagination - total_count = queryset.count() - queryset = self.do_pagination(request, queryset) + queryset = self.paginate_queryset(queryset) # Get list of fields to show in results if 'fields' in request.GET: @@ -418,16 +382,11 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - data = OrderedDict([ - ('meta', OrderedDict([ - ('total_count', total_count), - ])), - ('documents', [ - self.serialize_object(request, document, fields=fields) - for document in queryset - ]), - ]) - return Response(data) + data = [ + self.serialize_object(request, document, fields=fields) + for document in queryset + ] + return self.get_paginated_response(data) def detail_view(self, request, pk): document = get_object_or_404(Document, pk=pk) diff --git a/wagtail/contrib/wagtailapi/pagination.py b/wagtail/contrib/wagtailapi/pagination.py new file mode 100644 index 000000000..6cb470e06 --- /dev/null +++ b/wagtail/contrib/wagtailapi/pagination.py @@ -0,0 +1,45 @@ +from collections import OrderedDict + +from django.conf import settings + +from rest_framework.pagination import BasePagination +from rest_framework.response import Response + +from .utils import BadRequestError + + +class WagtailPagination(BasePagination): + def paginate_queryset(self, queryset, request, view=None): + limit_max = getattr(settings, 'WAGTAILAPI_LIMIT_MAX', 20) + + try: + offset = int(request.GET.get('offset', 0)) + assert offset >= 0 + except (ValueError, AssertionError): + raise BadRequestError("offset must be a positive integer") + + try: + limit = int(request.GET.get('limit', min(20, limit_max))) + + if limit > limit_max: + raise BadRequestError("limit cannot be higher than %d" % limit_max) + + assert limit >= 0 + except (ValueError, AssertionError): + raise BadRequestError("limit must be a positive integer") + + start = offset + stop = offset + limit + + self.view = view + self.total_count = queryset.count() + return queryset[start:stop] + + def get_paginated_response(self, data): + data = OrderedDict([ + ('meta', OrderedDict([ + ('total_count', self.total_count), + ])), + (self.view.name, data), + ]) + return Response(data)