From b6a4318379ac90e5b62efa26361c9933b28cea8d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 13:40:19 +0100 Subject: [PATCH] Refactor to use Responses and Renderers. --- wagtail/contrib/wagtailapi/api.py | 86 ++----------------------- wagtail/contrib/wagtailapi/endpoints.py | 77 +++++++++++++--------- wagtail/contrib/wagtailapi/renderers.py | 49 ++++++++++++++ wagtail/contrib/wagtailapi/utils.py | 22 +++++++ 4 files changed, 122 insertions(+), 112 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/renderers.py diff --git a/wagtail/contrib/wagtailapi/api.py b/wagtail/contrib/wagtailapi/api.py index 586e3d524..33877429f 100644 --- a/wagtail/contrib/wagtailapi/api.py +++ b/wagtail/contrib/wagtailapi/api.py @@ -1,95 +1,21 @@ -import json -from functools import wraps - from django.conf.urls import url, include -from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotFound, Http404 -from django.core.serializers.json import DjangoJSONEncoder -from django.core.urlresolvers import reverse -from taggit.managers import _TaggableManager -from taggit.models import Tag - -from wagtail.utils.urlpatterns import decorate_urlpatterns -from wagtail.wagtailcore.blocks import StreamValue - -from .endpoints import URLPath, ObjectDetailURL, PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint -from .utils import BadRequestError, get_base_url - - -def get_full_url(request, path): - base_url = get_base_url(request) or '' - return base_url + path +from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint class API(object): def __init__(self, endpoints): self.endpoints = endpoints - def find_model_detail_view(self, model): - for endpoint_name, endpoint in self.endpoints.items(): - if endpoint.has_model(model): - return 'wagtailapi_v1:%s:detail' % endpoint_name - - def make_response(self, request, data, response_cls=HttpResponse): - api = self - - class WagtailAPIJSONEncoder(DjangoJSONEncoder): - def default(self, o): - if isinstance(o, _TaggableManager): - return list(o.all()) - elif isinstance(o, Tag): - return o.name - elif isinstance(o, URLPath): - return get_full_url(request, o.path) - elif isinstance(o, ObjectDetailURL): - view = api.find_model_detail_view(o.model) - - if view: - return get_full_url(request, reverse(view, args=(o.pk, ))) - else: - return None - elif isinstance(o, StreamValue): - return o.stream_block.get_prep_value(o) - else: - return super(WagtailAPIJSONEncoder, self).default(o) - - return response_cls( - json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder), - content_type='application/json' - ) - - def api_view(self, view): - """ - This is a decorator that is applied to all API views. - - It is responsible for serialising the responses from the endpoints - and handling errors. - """ - @wraps(view) - def wrapper(request, *args, **kwargs): - # Catch exceptions and format them as JSON documents - try: - return self.make_response(request, view(request, *args, **kwargs)) - except Http404 as e: - return self.make_response(request, { - 'message': str(e) - }, response_cls=HttpResponseNotFound) - except BadRequestError as e: - return self.make_response(request, { - 'message': str(e) - }, response_cls=HttpResponseBadRequest) - - return wrapper - def get_urlpatterns(self): - return decorate_urlpatterns([ + return [ url(r'^%s/' % name, include(endpoint.get_urlpatterns(), namespace=name)) for name, endpoint in self.endpoints.items() - ], self.api_view) + ] v1 = API({ - 'pages': PagesAPIEndpoint(), - 'images': ImagesAPIEndpoint(), - 'documents': DocumentsAPIEndpoint(), + 'pages': PagesAPIEndpoint, + 'images': ImagesAPIEndpoint, + 'documents': DocumentsAPIEndpoint, }) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index c50d7eed6..4d6e15b72 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -10,6 +10,10 @@ from django.utils.encoding import force_text from django.shortcuts import get_object_or_404 from django.conf.urls import url from django.conf import settings +from django.http import Http404 + +from rest_framework import status +from rest_framework.response import Response from rest_framework.viewsets import ViewSet from wagtail.wagtailcore.models import Page @@ -19,29 +23,8 @@ from wagtail.wagtailcore.utils import resolve_model_string from wagtail.wagtailsearch.backends import get_search_backend from wagtail.utils.compat import get_related_model -from .utils import BadRequestError - - -class URLPath(object): - """ - This class represents a URL path that should be converted to a full URL. - - It is used when the domain that should be used is not known at the time - the URL was generated. It will get resolved to a full URL during - serialisation in api.py. - - One example use case is the documents endpoint adding download URLs into - the JSON. The endpoint does not know the domain name to use at the time so - returns one of these instead. - """ - def __init__(self, path): - self.path = path - - -class ObjectDetailURL(object): - def __init__(self, model, pk): - self.model = model - self.pk = pk +from .renderers import WagtailJSONRenderer +from .utils import BadRequestError, URLPath, ObjectDetailURL def get_api_data(obj, fields): @@ -96,6 +79,8 @@ def get_api_data(obj, fields): class BaseAPIEndpoint(ViewSet): + renderer_classes = [WagtailJSONRenderer] + known_query_parameters = frozenset([ 'limit', 'offset', @@ -104,6 +89,15 @@ class BaseAPIEndpoint(ViewSet): 'search', ]) + def handle_exception(self, exc): + if isinstance(exc, Http404): + data = {'message': str(exc)} + return Response(data, status=status.HTTP_404_NOT_FOUND) + elif isinstance(exc, BadRequestError): + data = {'message': str(exc)} + return Response(data, status=status.HTTP_400_BAD_REQUEST) + return super(BaseAPIEndpoint, self).handle_exception(exc) + def listing_view(self, request): return NotImplemented @@ -300,15 +294,28 @@ class BaseAPIEndpoint(ViewSet): return queryset[start:stop] - def get_urlpatterns(self): + @classmethod + def get_urlpatterns(cls): """ This returns a list of URL patterns for the endpoint """ return [ - url(r'^$', self.listing_view, name='listing'), - url(r'^(\d+)/$', self.detail_view, name='detail'), + url(r'^$', cls.as_view({'get': 'listing_view'}), name='listing'), + url(r'^(\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'), ] + def find_model_detail_view(self, model): + # TODO: Needs refactoring. This is currently duplicated, and also + # does a bit of a dance around instantiating these classes. + endpoints = { + 'pages': PagesAPIEndpoint(), + 'images': ImagesAPIEndpoint(), + 'documents': DocumentsAPIEndpoint(), + } + for endpoint_name, endpoint in endpoints.items(): + if endpoint.has_model(model): + return 'wagtailapi_v1:%s:detail' % endpoint_name + def has_model(self, model): return False @@ -443,7 +450,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - return OrderedDict([ + data = OrderedDict([ ('meta', OrderedDict([ ('total_count', total_count), ])), @@ -452,10 +459,12 @@ class PagesAPIEndpoint(BaseAPIEndpoint): for page in queryset ]), ]) + return Response(data) def detail_view(self, request, pk): page = get_object_or_404(self.get_queryset(request), pk=pk).specific - return self.serialize_object(request, page, all_fields=True, show_details=True) + data = self.serialize_object(request, page, all_fields=True, show_details=True) + return Response(data) def has_model(self, model): return issubclass(model, Page) @@ -497,7 +506,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - return OrderedDict([ + data = OrderedDict([ ('meta', OrderedDict([ ('total_count', total_count), ])), @@ -506,10 +515,12 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): for image in queryset ]), ]) + return Response(data) def detail_view(self, request, pk): image = get_object_or_404(self.get_queryset(request), pk=pk) - return self.serialize_object(request, image, all_fields=True) + data = self.serialize_object(request, image, all_fields=True) + return Response(data) def has_model(self, model): return model == self.model @@ -555,7 +566,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - return OrderedDict([ + data = OrderedDict([ ('meta', OrderedDict([ ('total_count', total_count), ])), @@ -564,10 +575,12 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): for document in queryset ]), ]) + return Response(data) def detail_view(self, request, pk): document = get_object_or_404(Document, pk=pk) - return self.serialize_object(request, document, all_fields=True, show_details=True) + data = self.serialize_object(request, document, all_fields=True, show_details=True) + return Response(data) def has_model(self, model): return model == Document diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py new file mode 100644 index 000000000..c58ea32c6 --- /dev/null +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -0,0 +1,49 @@ +import json + +from django.core.serializers.json import DjangoJSONEncoder +from django.core.urlresolvers import reverse + +from rest_framework import renderers + +from taggit.managers import _TaggableManager +from taggit.models import Tag + +from wagtail.wagtailcore.blocks import StreamValue + +from .utils import URLPath, ObjectDetailURL, get_base_url + + +def get_full_url(request, path): + base_url = get_base_url(request) or '' + return base_url + path + + +class WagtailJSONRenderer(renderers.BaseRenderer): + media_type = 'application/json' + charset = None + + def render(self, data, media_type=None, renderer_context=None): + endpoint = renderer_context['view'] + request = renderer_context['request'] + + class WagtailAPIJSONEncoder(DjangoJSONEncoder): + def default(self, o): + if isinstance(o, _TaggableManager): + return list(o.all()) + elif isinstance(o, Tag): + return o.name + elif isinstance(o, URLPath): + return get_full_url(request, o.path) + elif isinstance(o, ObjectDetailURL): + view = endpoint.find_model_detail_view(o.model) + + if view: + return get_full_url(request, reverse(view, args=(o.pk, ))) + else: + return None + elif isinstance(o, StreamValue): + return o.stream_block.get_prep_value(o) + else: + return super(WagtailAPIJSONEncoder, self).default(o) + + return json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder) diff --git a/wagtail/contrib/wagtailapi/utils.py b/wagtail/contrib/wagtailapi/utils.py index 11af445d2..483e4f51a 100644 --- a/wagtail/contrib/wagtailapi/utils.py +++ b/wagtail/contrib/wagtailapi/utils.py @@ -14,3 +14,25 @@ def get_base_url(request=None): base_url_parsed = urlparse(base_url) return base_url_parsed.scheme + '://' + base_url_parsed.netloc + + +class URLPath(object): + """ + This class represents a URL path that should be converted to a full URL. + + It is used when the domain that should be used is not known at the time + the URL was generated. It will get resolved to a full URL during + serialisation in api.py. + + One example use case is the documents endpoint adding download URLs into + the JSON. The endpoint does not know the domain name to use at the time so + returns one of these instead. + """ + def __init__(self, path): + self.path = path + + +class ObjectDetailURL(object): + def __init__(self, model, pk): + self.model = model + self.pk = pk