From e1978f6606ab6ed57ad822996593d1bfa21c3dfe Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 16:17:40 +0100 Subject: [PATCH] Refactor filters --- wagtail/contrib/wagtailapi/endpoints.py | 177 +++--------------------- wagtail/contrib/wagtailapi/filters.py | 148 ++++++++++++++++++++ 2 files changed, 168 insertions(+), 157 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/filters.py diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 4d6e15b72..594465eaf 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -14,15 +14,18 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response -from rest_framework.viewsets import ViewSet +from rest_framework.viewsets import GenericViewSet from wagtail.wagtailcore.models import Page from wagtail.wagtailimages.models import get_image_model from wagtail.wagtaildocs.models import Document from wagtail.wagtailcore.utils import resolve_model_string -from wagtail.wagtailsearch.backends import get_search_backend from wagtail.utils.compat import get_related_model +from .filters import ( + FieldsFilter, OrderingFilter, SearchFilter, + ChildOfFilter, DescendantOfFilter +) from .renderers import WagtailJSONRenderer from .utils import BadRequestError, URLPath, ObjectDetailURL @@ -78,8 +81,9 @@ def get_api_data(obj, fields): continue -class BaseAPIEndpoint(ViewSet): +class BaseAPIEndpoint(GenericViewSet): renderer_classes = [WagtailJSONRenderer] + filter_classes = [] known_query_parameters = frozenset([ 'limit', @@ -174,98 +178,6 @@ class BaseAPIEndpoint(ViewSet): if unknown_parameters: raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters))) - def do_field_filtering(self, request, queryset): - """ - This performs field level filtering on the result set - Eg: ?title=James Joyce - """ - fields = set(self.get_api_fields(queryset.model)).union({'id'}) - - for field_name, value in request.GET.items(): - if field_name in fields: - field = getattr(queryset.model, field_name, None) - - if isinstance(field, _TaggableManager): - for tag in value.split(','): - queryset = queryset.filter(**{field_name + '__name': tag}) - - # Stick a message on the queryset to indicate that tag filtering has been performed - # This will let the do_search method know that it must raise an error as searching - # and tag filtering at the same time is not supported - queryset._filtered_by_tag = True - else: - queryset = queryset.filter(**{field_name: value}) - - return queryset - - def do_ordering(self, request, queryset): - """ - This applies ordering to the result set - Eg: ?order=title - - It also supports reverse ordering - Eg: ?order=-title - - And random ordering - Eg: ?order=random - """ - if 'order' in request.GET: - # Prevent ordering while searching - if 'search' in request.GET: - raise BadRequestError("ordering with a search query is not supported") - - order_by = request.GET['order'] - - # Random ordering - if order_by == 'random': - # Prevent ordering by random with offset - if 'offset' in request.GET: - raise BadRequestError("random ordering with offset is not supported") - - return queryset.order_by('?') - - # Check if reverse ordering is set - if order_by.startswith('-'): - reverse_order = True - order_by = order_by[1:] - else: - reverse_order = False - - # Add ordering - if order_by == 'id' or order_by in self.get_api_fields(queryset.model): - queryset = queryset.order_by(order_by) - else: - # Unknown field - raise BadRequestError("cannot order by '%s' (unknown field)" % order_by) - - # Reverse order - if reverse_order: - queryset = queryset.reverse() - - return queryset - - def do_search(self, request, queryset): - """ - This performs a full-text search on the result set - Eg: ?search=James Joyce - """ - search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True) - - if 'search' in request.GET: - if not search_enabled: - raise BadRequestError("search is disabled") - - # Searching and filtering by tag at the same time is not supported - if getattr(queryset, '_filtered_by_tag', False): - raise BadRequestError("filtering by tag with a search query is not supported") - - search_query = request.GET['search'] - - sb = get_search_backend() - queryset = sb.search(search_query, queryset) - - return queryset - def do_pagination(self, request, queryset): """ This performs limit/offset based pagination on the result set @@ -326,6 +238,10 @@ class PagesAPIEndpoint(BaseAPIEndpoint): 'child_of', 'descendant_of', ]) + filter_backends = [ + FieldsFilter, ChildOfFilter, DescendantOfFilter, + OrderingFilter, SearchFilter + ] def get_queryset(self, request, model=Page): # Get live pages that are not in a private section @@ -385,42 +301,6 @@ class PagesAPIEndpoint(BaseAPIEndpoint): except LookupError: raise BadRequestError("type doesn't exist") - def do_child_of_filter(self, request, queryset): - if 'child_of' in request.GET: - try: - parent_page_id = int(request.GET['child_of']) - assert parent_page_id >= 0 - except (ValueError, AssertionError): - raise BadRequestError("child_of must be a positive integer") - - try: - parent_page = self.get_queryset(request).get(id=parent_page_id) - queryset = queryset.child_of(parent_page) - queryset._filtered_by_child_of = True - return queryset - except Page.DoesNotExist: - raise BadRequestError("parent page doesn't exist") - - return queryset - - def do_descendant_of_filter(self, request, queryset): - if 'descendant_of' in request.GET: - if getattr(queryset, '_filtered_by_child_of', False): - raise BadRequestError("filtering by descendant_of with child_of is not supported") - try: - ancestor_page_id = int(request.GET['descendant_of']) - assert ancestor_page_id >= 0 - except (ValueError, AssertionError): - raise BadRequestError("descendant_of must be a positive integer") - - try: - ancestor_page = self.get_queryset(request).get(id=ancestor_page_id) - return queryset.descendant_of(ancestor_page) - except Page.DoesNotExist: - raise BadRequestError("ancestor page doesn't exist") - - return queryset - def listing_view(self, request): # Get model and queryset model = self.get_model(request) @@ -429,16 +309,8 @@ class PagesAPIEndpoint(BaseAPIEndpoint): # Check query paramters self.check_query_parameters(request, queryset) - # Filtering - queryset = self.do_field_filtering(request, queryset) - queryset = self.do_child_of_filter(request, queryset) - queryset = self.do_descendant_of_filter(request, queryset) - - # Ordering - queryset = self.do_ordering(request, queryset) - - # Search - queryset = self.do_search(request, queryset) + # Filtering, Ancestor/Descendant, Ordering, Search. + queryset = self.filter_queryset(queryset) # Pagination total_count = queryset.count() @@ -472,6 +344,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): model = get_image_model() + filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] def get_queryset(self, request): return self.model.objects.all().order_by('id') @@ -487,14 +360,8 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): # Check query paramters self.check_query_parameters(request, queryset) - # Filtering - queryset = self.do_field_filtering(request, queryset) - - # Ordering - queryset = self.do_ordering(request, queryset) - - # Search - queryset = self.do_search(request, queryset) + # Filtering, Ordering, Search. + queryset = self.filter_queryset(queryset) # Pagination total_count = queryset.count() @@ -527,6 +394,8 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): class DocumentsAPIEndpoint(BaseAPIEndpoint): + filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] + def get_api_fields(self, model): api_fields = ['title', 'tags'] api_fields.extend(super(DocumentsAPIEndpoint, self).get_api_fields(model)) @@ -547,14 +416,8 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): # Check query paramters self.check_query_parameters(request, queryset) - # Filtering - queryset = self.do_field_filtering(request, queryset) - - # Ordering - queryset = self.do_ordering(request, queryset) - - # Search - queryset = self.do_search(request, queryset) + # Filtering, Ordering, Search. + queryset = self.filter_queryset(queryset) # Pagination total_count = queryset.count() diff --git a/wagtail/contrib/wagtailapi/filters.py b/wagtail/contrib/wagtailapi/filters.py new file mode 100644 index 000000000..633769864 --- /dev/null +++ b/wagtail/contrib/wagtailapi/filters.py @@ -0,0 +1,148 @@ +from django.conf import settings + +from rest_framework.filters import BaseFilterBackend + +from taggit.managers import _TaggableManager + +from wagtail.wagtailcore.models import Page +from wagtail.wagtailsearch.backends import get_search_backend + +from .utils import BadRequestError + + +class FieldsFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + This performs field level filtering on the result set + Eg: ?title=James Joyce + """ + fields = set(view.get_api_fields(queryset.model)).union({'id'}) + + for field_name, value in request.GET.items(): + if field_name in fields: + field = getattr(queryset.model, field_name, None) + + if isinstance(field, _TaggableManager): + for tag in value.split(','): + queryset = queryset.filter(**{field_name + '__name': tag}) + + # Stick a message on the queryset to indicate that tag filtering has been performed + # This will let the do_search method know that it must raise an error as searching + # and tag filtering at the same time is not supported + queryset._filtered_by_tag = True + else: + queryset = queryset.filter(**{field_name: value}) + + return queryset + + +class OrderingFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + This applies ordering to the result set + Eg: ?order=title + + It also supports reverse ordering + Eg: ?order=-title + + And random ordering + Eg: ?order=random + """ + if 'order' in request.GET: + # Prevent ordering while searching + if 'search' in request.GET: + raise BadRequestError("ordering with a search query is not supported") + + order_by = request.GET['order'] + + # Random ordering + if order_by == 'random': + # Prevent ordering by random with offset + if 'offset' in request.GET: + raise BadRequestError("random ordering with offset is not supported") + + return queryset.order_by('?') + + # Check if reverse ordering is set + if order_by.startswith('-'): + reverse_order = True + order_by = order_by[1:] + else: + reverse_order = False + + # Add ordering + if order_by == 'id' or order_by in view.get_api_fields(queryset.model): + queryset = queryset.order_by(order_by) + else: + # Unknown field + raise BadRequestError("cannot order by '%s' (unknown field)" % order_by) + + # Reverse order + if reverse_order: + queryset = queryset.reverse() + + return queryset + + +class SearchFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + This performs a full-text search on the result set + Eg: ?search=James Joyce + """ + search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True) + + if 'search' in request.GET: + if not search_enabled: + raise BadRequestError("search is disabled") + + # Searching and filtering by tag at the same time is not supported + if getattr(queryset, '_filtered_by_tag', False): + raise BadRequestError("filtering by tag with a search query is not supported") + + search_query = request.GET['search'] + + sb = get_search_backend() + queryset = sb.search(search_query, queryset) + + return queryset + + +class ChildOfFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + if 'child_of' in request.GET: + try: + parent_page_id = int(request.GET['child_of']) + assert parent_page_id >= 0 + except (ValueError, AssertionError): + raise BadRequestError("child_of must be a positive integer") + + try: + parent_page = view.get_queryset(request).get(id=parent_page_id) + queryset = queryset.child_of(parent_page) + queryset._filtered_by_child_of = True + return queryset + except Page.DoesNotExist: + raise BadRequestError("parent page doesn't exist") + + return queryset + + +class DescendantOfFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + if 'descendant_of' in request.GET: + if getattr(queryset, '_filtered_by_child_of', False): + raise BadRequestError("filtering by descendant_of with child_of is not supported") + try: + ancestor_page_id = int(request.GET['descendant_of']) + assert ancestor_page_id >= 0 + except (ValueError, AssertionError): + raise BadRequestError("descendant_of must be a positive integer") + + try: + ancestor_page = view.get_queryset(request).get(id=ancestor_page_id) + return queryset.descendant_of(ancestor_page) + except Page.DoesNotExist: + raise BadRequestError("ancestor page doesn't exist") + + return queryset