diff --git a/wagtail/api/v2/endpoints.py b/wagtail/api/v2/endpoints.py index b045140ce..6474ec266 100644 --- a/wagtail/api/v2/endpoints.py +++ b/wagtail/api/v2/endpoints.py @@ -18,7 +18,7 @@ from wagtail.wagtaildocs.models import Document from .filters import ( FieldsFilter, OrderingFilter, SearchFilter, - ChildOfFilter, DescendantOfFilter + RestrictedChildOfFilter, RestrictedDescendantOfFilter ) from .pagination import WagtailPagination from .serializers import BaseSerializer, PageSerializer, DocumentSerializer, ImageSerializer, get_serializer_class @@ -189,8 +189,8 @@ class PagesAPIEndpoint(BaseAPIEndpoint): base_serializer_class = PageSerializer filter_backends = [ FieldsFilter, - ChildOfFilter, - DescendantOfFilter, + RestrictedChildOfFilter, + RestrictedDescendantOfFilter, OrderingFilter, SearchFilter ] diff --git a/wagtail/api/v2/filters.py b/wagtail/api/v2/filters.py index dedb69813..dc989130c 100644 --- a/wagtail/api/v2/filters.py +++ b/wagtail/api/v2/filters.py @@ -110,42 +110,73 @@ class SearchFilter(BaseFilterBackend): class ChildOfFilter(BaseFilterBackend): + """ + Implements the ?child_of filter used to filter the results to only contain + pages that are direct children of the specified page. + """ + def get_page_by_id(self, request, page_id): + return Page.objects.get(id=page_id) + def filter_queryset(self, request, queryset, view): if 'child_of' in request.GET: try: parent_page_id = int(request.GET['child_of']) assert parent_page_id >= 0 + + parent_page = self.get_page_by_id(request, parent_page_id) except (ValueError, AssertionError): raise BadRequestError("child_of must be a positive integer") - - site_pages = pages_for_site(request.site) - try: - parent_page = site_pages.get(id=parent_page_id) - queryset = queryset.child_of(parent_page) - queryset._filtered_by_child_of = True - return queryset except Page.DoesNotExist: raise BadRequestError("parent page doesn't exist") + queryset = queryset.child_of(parent_page) + queryset._filtered_by_child_of = True + return queryset +class RestrictedChildOfFilter(ChildOfFilter): + """ + A restricted version of ChildOfFilter that only allows pages in the current + site to be specified. + """ + def get_page_by_id(self, request, page_id): + site_pages = pages_for_site(request.site) + return site_pages.get(id=page_id) + + class DescendantOfFilter(BaseFilterBackend): + """ + Implements the ?decendant_of filter which limits the set of pages to a + particular branch of the page tree. + """ + def get_page_by_id(self, request, page_id): + return Page.objects.get(id=page_id) + def filter_queryset(self, request, queryset, view): if 'descendant_of' in request.GET: if getattr(queryset, '_filtered_by_child_of', False): raise BadRequestError("filtering by descendant_of with child_of is not supported") try: - ancestor_page_id = int(request.GET['descendant_of']) - assert ancestor_page_id >= 0 + parent_page_id = int(request.GET['descendant_of']) + assert parent_page_id >= 0 + + parent_page = self.get_page_by_id(request, parent_page_id) except (ValueError, AssertionError): raise BadRequestError("descendant_of must be a positive integer") - - site_pages = pages_for_site(request.site) - try: - ancestor_page = site_pages.get(id=ancestor_page_id) - return queryset.descendant_of(ancestor_page) except Page.DoesNotExist: raise BadRequestError("ancestor page doesn't exist") + queryset = queryset.descendant_of(parent_page) + return queryset + + +class RestrictedDescendantOfFilter(DescendantOfFilter): + """ + A restricted version of DecendantOfFilter that only allows pages in the current + site to be specified. + """ + def get_page_by_id(self, request, page_id): + site_pages = pages_for_site(request.site) + return site_pages.get(id=page_id)