From a2b97737eb82fa5fa99df1a502c395323961a655 Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Tue, 5 Nov 2019 12:46:21 +0000 Subject: [PATCH] Add get_base_queryset method to PagesAPIEndpoint This method provides a single place to define a base queryset to be used as a base for filtering but also in the descendant_of and child_of filters. This improves parity between the Admin API and public API as well. --- wagtail/admin/api/endpoints.py | 42 ++++++++++++++--------------- wagtail/api/v2/endpoints.py | 43 +++++++++++++++++++----------- wagtail/api/v2/filters.py | 48 ++++------------------------------ 3 files changed, 54 insertions(+), 79 deletions(-) diff --git a/wagtail/admin/api/endpoints.py b/wagtail/admin/api/endpoints.py index 183af0a44..f68492b3b 100644 --- a/wagtail/admin/api/endpoints.py +++ b/wagtail/admin/api/endpoints.py @@ -3,10 +3,7 @@ from collections import OrderedDict from rest_framework.authentication import SessionAuthentication from wagtail.api.v2.endpoints import PagesAPIViewSet -from wagtail.api.v2.filters import ( - ChildOfFilter, DescendantOfFilter, FieldsFilter, OrderingFilter, - SearchFilter) -from wagtail.api.v2.utils import BadRequestError, filter_page_type, page_models_from_string +from wagtail.api.v2.utils import filter_page_type from wagtail.core.models import Page from .filters import ForExplorerFilter, HasChildrenFilter @@ -17,16 +14,10 @@ class PagesAdminAPIViewSet(PagesAPIViewSet): base_serializer_class = AdminPageSerializer authentication_classes = [SessionAuthentication] - # Use unrestricted child_of/descendant_of filters - # Add has_children filter - filter_backends = [ - FieldsFilter, - ChildOfFilter, - DescendantOfFilter, - ForExplorerFilter, + # Add has_children and for_explorer filters + filter_backends = PagesAPIViewSet.filter_backends + [ HasChildrenFilter, - OrderingFilter, - SearchFilter, + ForExplorerFilter, ] meta_fields = PagesAPIViewSet.meta_fields + [ @@ -57,16 +48,20 @@ class PagesAdminAPIViewSet(PagesAPIViewSet): 'has_children' ]) - def get_queryset(self): - request = self.request + def get_root_page(self): + """ + Returns the page that is used when the `&child_of=root` filter is used. + """ + return Page.get_first_root_node() - # Allow pages to be filtered to a specific type - try: - models = page_models_from_string(request.GET.get('type', 'wagtailcore.Page')) - except (LookupError, ValueError): - raise BadRequestError("type doesn't exist") + def get_base_queryset(self, models=None): + """ + Returns a queryset containing all pages that can be seen by this user. - if not models: + This is used as the base for get_queryset and is also used to find the + parent pages when using the child_of and descendant_of filters as well. + """ + if models is None: models = [Page] if len(models) == 1: @@ -77,6 +72,11 @@ class PagesAdminAPIViewSet(PagesAPIViewSet): # Filter pages by specified models queryset = filter_page_type(queryset, models) + return queryset + + def get_queryset(self): + queryset = super().get_queryset() + # Hide root page # TODO: Add "include_root" flag queryset = queryset.exclude(depth=1).specific() diff --git a/wagtail/api/v2/endpoints.py b/wagtail/api/v2/endpoints.py index 2b552437c..2811355bc 100644 --- a/wagtail/api/v2/endpoints.py +++ b/wagtail/api/v2/endpoints.py @@ -14,9 +14,7 @@ from rest_framework.viewsets import GenericViewSet from wagtail.api import APIField from wagtail.core.models import Page -from .filters import ( - FieldsFilter, OrderingFilter, RestrictedChildOfFilter, RestrictedDescendantOfFilter, - SearchFilter) +from .filters import ChildOfFilter, DescendantOfFilter, FieldsFilter, OrderingFilter, SearchFilter from .pagination import WagtailPagination from .serializers import BaseSerializer, PageSerializer, get_serializer_class from .utils import ( @@ -366,8 +364,8 @@ class PagesAPIViewSet(BaseAPIViewSet): base_serializer_class = PageSerializer filter_backends = [ FieldsFilter, - RestrictedChildOfFilter, - RestrictedDescendantOfFilter, + ChildOfFilter, + DescendantOfFilter, OrderingFilter, SearchFilter ] @@ -401,16 +399,20 @@ class PagesAPIViewSet(BaseAPIViewSet): name = 'pages' model = Page - def get_queryset(self): - request = self.request + def get_root_page(self): + """ + Returns the page that is used when the `&child_of=root` filter is used. + """ + return self.request.site.root_page - # Allow pages to be filtered to a specific type - try: - models = page_models_from_string(request.GET.get('type', 'wagtailcore.Page')) - except (LookupError, ValueError): - raise BadRequestError("type doesn't exist") + def get_base_queryset(self, models=None): + """ + Returns a queryset containing all pages that can be seen by this user. - if not models: + This is used as the base for get_queryset and is also used to find the + parent pages when using the child_of and descendant_of filters as well. + """ + if models is None: models = [Page] if len(models) == 1: @@ -425,14 +427,25 @@ class PagesAPIViewSet(BaseAPIViewSet): queryset = queryset.public().live() # Filter by site - if request.site: - queryset = queryset.descendant_of(request.site.root_page, inclusive=True) + if self.request.site: + queryset = queryset.descendant_of(self.request.site.root_page, inclusive=True) else: # No sites configured queryset = queryset.none() return queryset + def get_queryset(self): + request = self.request + + # Allow pages to be filtered to a specific type + try: + models = page_models_from_string(request.GET.get('type', 'wagtailcore.Page')) + except (LookupError, ValueError): + raise BadRequestError("type doesn't exist") + + return self.get_base_queryset(models) + def get_object(self): base = super().get_object() return base.specific diff --git a/wagtail/api/v2/filters.py b/wagtail/api/v2/filters.py index 0356c2b8c..5322804dc 100644 --- a/wagtail/api/v2/filters.py +++ b/wagtail/api/v2/filters.py @@ -7,7 +7,7 @@ from wagtail.core.models import Page from wagtail.search.backends import get_search_backend from wagtail.search.backends.base import FilterFieldError, OrderByFieldError -from .utils import BadRequestError, pages_for_site, parse_boolean +from .utils import BadRequestError, parse_boolean class FieldsFilter(BaseFilterBackend): @@ -132,12 +132,6 @@ class ChildOfFilter(BaseFilterBackend): Implements the ?child_of filter used to filter the results to only contain pages that are direct children of the specified page. """ - def get_root_page(self, request): - return Page.get_first_root_node() - - def get_page_by_id(self, request, page_id): - return Page.objects.get(id=page_id) - def filter_queryset(self, request, queryset, view): if 'child_of' in request.GET: try: @@ -145,10 +139,10 @@ class ChildOfFilter(BaseFilterBackend): if parent_page_id < 0: raise ValueError() - parent_page = self.get_page_by_id(request, parent_page_id) + parent_page = view.get_base_queryset().get(id=parent_page_id) except ValueError: if request.GET['child_of'] == 'root': - parent_page = self.get_root_page(request) + parent_page = view.get_root_page() else: raise BadRequestError("child_of must be a positive integer") except Page.DoesNotExist: @@ -160,30 +154,11 @@ class ChildOfFilter(BaseFilterBackend): return queryset -class RestrictedChildOfFilter(ChildOfFilter): - """ - A restricted version of ChildOfFilter that only allows pages in the current - site to be specified. - """ - def get_root_page(self, request): - return request.site.root_page - - def get_page_by_id(self, request, page_id): - site_pages = pages_for_site(request.site) - return site_pages.get(id=page_id) - - class DescendantOfFilter(BaseFilterBackend): """ Implements the ?decendant_of filter which limits the set of pages to a particular branch of the page tree. """ - def get_root_page(self, request): - return Page.get_first_root_node() - - def get_page_by_id(self, request, page_id): - return Page.objects.get(id=page_id) - def filter_queryset(self, request, queryset, view): if 'descendant_of' in request.GET: if hasattr(queryset, '_filtered_by_child_of'): @@ -193,10 +168,10 @@ class DescendantOfFilter(BaseFilterBackend): if parent_page_id < 0: raise ValueError() - parent_page = self.get_page_by_id(request, parent_page_id) + parent_page = view.get_base_queryset().get(id=parent_page_id) except ValueError: if request.GET['descendant_of'] == 'root': - parent_page = self.get_root_page(request) + parent_page = view.get_root_page() else: raise BadRequestError("descendant_of must be a positive integer") except Page.DoesNotExist: @@ -205,16 +180,3 @@ class DescendantOfFilter(BaseFilterBackend): queryset = queryset.descendant_of(parent_page) return queryset - - -class RestrictedDescendantOfFilter(DescendantOfFilter): - """ - A restricted version of DecendantOfFilter that only allows pages in the current - site to be specified. - """ - def get_root_page(self, request): - return request.site.root_page - - def get_page_by_id(self, request, page_id): - site_pages = pages_for_site(request.site) - return site_pages.get(id=page_id)