diff --git a/wagtail/contrib/wagtailapi/serializers.py b/wagtail/contrib/wagtailapi/serializers.py index c0f26101c..4cd323e16 100644 --- a/wagtail/contrib/wagtailapi/serializers.py +++ b/wagtail/contrib/wagtailapi/serializers.py @@ -3,13 +3,13 @@ from __future__ import absolute_import from collections import OrderedDict from django.db import models -from django.utils.encoding import force_text from modelcluster.models import get_all_child_relations from taggit.managers import TaggableManager from rest_framework.serializers import BaseSerializer +from rest_framework.fields import Field, ReadOnlyField from wagtail.utils.compat import get_related_model from wagtail.wagtailcore.models import Page @@ -18,30 +18,14 @@ from wagtail.wagtailcore import fields as wagtailcore_fields from .utils import ObjectDetailURL, URLPath, BadRequestError, pages_for_site -class Field(object): - def __init__(self, field_name): - self.field_name = field_name - - def get_attribute(self, instance): - return getattr(instance, self.field_name) - - def to_representation(self, value): - return value - - -class AttrField(Field): - def to_representation(self, value): - return force_text(value, strings_only=True) - - class ChildRelationField(Field): - def __init__(self, field_name, fields): - self.field_name = field_name - self.fields = fields + def __init__(self, *args, **kwargs): + self.child_fields = kwargs.pop('child_fields') + super(ChildRelationField, self).__init__(*args, **kwargs) def to_representation(self, value): return [ - dict(get_api_data(child_object, self.fields)) + dict(get_api_data(child_object, self.child_fields)) for child_object in value.all() ] @@ -82,7 +66,7 @@ def get_serializer_fields(model, fields): for field_name in fields: # Check child relations if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'): - yield ChildRelationField(field_name, child_relations[field_name].api_fields) + yield field_name, ChildRelationField, {'child_fields': child_relations[field_name].api_fields} continue # Check django fields @@ -90,13 +74,13 @@ def get_serializer_fields(model, fields): field = model._meta.get_field(field_name) if field.rel and isinstance(field.rel, models.ManyToOneRel): - yield RelatedObjectField(field_name) + yield field_name, RelatedObjectField, {} elif isinstance(field, wagtailcore_fields.StreamField): - yield StreamField(field_name) + yield field_name, StreamField, {} elif isinstance(field, TaggableManager): - yield TagsField(field_name) + yield field_name, TagsField, {} else: - yield AttrField(field_name) + yield field_name, ReadOnlyField, {} continue except models.fields.FieldDoesNotExist: @@ -104,14 +88,17 @@ def get_serializer_fields(model, fields): # Check attributes if hasattr(model, field_name): - yield AttrField(field_name) + yield field_name, ReadOnlyField, {} continue def get_api_data(obj, fields): serializer_fields = get_serializer_fields(type(obj), fields) - for field in serializer_fields: + for field_name, field_class, field_kwargs in serializer_fields: + field = field_class(**field_kwargs) + field.bind(field_name, None) + value = field.get_attribute(obj) if value is not None: