From 67ff876542d6f9cfe600cd069bdb5cd1c429cd59 Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Fri, 21 Aug 2015 12:00:04 +0100 Subject: [PATCH] Simplified logic for finding which endpoint contains model --- wagtail/contrib/wagtailapi/endpoints.py | 20 ++++---------------- wagtail/contrib/wagtailapi/renderers.py | 2 +- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index c6a2489fc..e5487203b 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -29,6 +29,7 @@ class BaseAPIEndpoint(GenericViewSet): pagination_class = WagtailPagination base_serializer_class = BaseSerializer filter_classes = [] + model = None # Set on subclass queryset = None # Set on subclasses or implement `get_queryset()`. known_query_parameters = frozenset([ @@ -164,10 +165,6 @@ class BaseAPIEndpoint(GenericViewSet): url(r'^(?P\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'), ] - @classmethod - def has_model(cls, model): - return NotImplemented - class PagesAPIEndpoint(BaseAPIEndpoint): base_serializer_class = PageSerializer @@ -185,6 +182,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): ]) extra_api_fields = ['title'] name = 'pages' + model = Page def get_queryset(self): request = self.request @@ -213,10 +211,6 @@ class PagesAPIEndpoint(BaseAPIEndpoint): base = super(PagesAPIEndpoint, self).get_object() return base.specific - @classmethod - def has_model(cls, model): - return issubclass(model, Page) - class ImagesAPIEndpoint(BaseAPIEndpoint): queryset = get_image_model().objects.all().order_by('id') @@ -224,10 +218,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags', 'width', 'height'] name = 'images' - - @classmethod - def has_model(cls, model): - return model == get_image_model() + model = get_image_model() class DocumentsAPIEndpoint(BaseAPIEndpoint): @@ -236,7 +227,4 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags'] name = 'documents' - - @classmethod - def has_model(cls, model): - return model == Document + model = Document diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py index f9828f3a2..ccae7e1d7 100644 --- a/wagtail/contrib/wagtailapi/renderers.py +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -16,7 +16,7 @@ def get_full_url(request, path): def find_model_detail_view(model, endpoints): for endpoint in endpoints: - if endpoint.has_model(model): + if issubclass(model, endpoint.model): return 'wagtailapi_v1:%s:detail' % endpoint.name