diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py index d7f9a5e90..f9828f3a2 100644 --- a/wagtail/contrib/wagtailapi/renderers.py +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -6,11 +6,6 @@ from django.utils.six import text_type from rest_framework import renderers -from taggit.managers import _TaggableManager -from taggit.models import Tag - -from wagtail.wagtailcore.blocks import StreamValue - from .utils import URLPath, ObjectDetailURL, get_base_url @@ -35,11 +30,7 @@ class WagtailJSONRenderer(renderers.BaseRenderer): class WagtailAPIJSONEncoder(DjangoJSONEncoder): def default(self, o): - if isinstance(o, _TaggableManager): - return list(o.all()) - elif isinstance(o, Tag): - return o.name - elif isinstance(o, URLPath): + if isinstance(o, URLPath): return get_full_url(request, o.path) elif isinstance(o, ObjectDetailURL): detail_view = find_model_detail_view(o.model, endpoints) @@ -48,8 +39,6 @@ class WagtailJSONRenderer(renderers.BaseRenderer): return get_full_url(request, reverse(detail_view, args=(o.pk, ))) else: return None - elif isinstance(o, StreamValue): - return o.stream_block.get_prep_value(o) else: return super(WagtailAPIJSONEncoder, self).default(o) diff --git a/wagtail/contrib/wagtailapi/serializers.py b/wagtail/contrib/wagtailapi/serializers.py index a28f6f2f5..c0f26101c 100644 --- a/wagtail/contrib/wagtailapi/serializers.py +++ b/wagtail/contrib/wagtailapi/serializers.py @@ -7,65 +7,119 @@ from django.utils.encoding import force_text from modelcluster.models import get_all_child_relations +from taggit.managers import TaggableManager + from rest_framework.serializers import BaseSerializer from wagtail.utils.compat import get_related_model from wagtail.wagtailcore.models import Page +from wagtail.wagtailcore import fields as wagtailcore_fields from .utils import ObjectDetailURL, URLPath, BadRequestError, pages_for_site -def get_api_data(obj, fields): +class Field(object): + def __init__(self, field_name): + self.field_name = field_name + + def get_attribute(self, instance): + return getattr(instance, self.field_name) + + def to_representation(self, value): + return value + + +class AttrField(Field): + def to_representation(self, value): + return force_text(value, strings_only=True) + + +class ChildRelationField(Field): + def __init__(self, field_name, fields): + self.field_name = field_name + self.fields = fields + + def to_representation(self, value): + return [ + dict(get_api_data(child_object, self.fields)) + for child_object in value.all() + ] + + +class RelatedObjectField(Field): + def to_representation(self, value): + model = type(value) + + return OrderedDict([ + ('id', value.pk), + ('meta', OrderedDict([ + ('type', model._meta.app_label + '.' + model.__name__), + ('detail_url', ObjectDetailURL(model, value.pk)), + ])), + ]) + + +class StreamField(Field): + def to_representation(self, value): + return value.stream_block.get_prep_value(value) + + +class TagsField(Field): + def to_representation(self, value): + return list(value.all().values_list('name', flat=True)) + + +def get_serializer_fields(model, fields): # Find any child relations (pages only) child_relations = {} - if isinstance(obj, Page): + if issubclass(model, Page): child_relations = { child_relation.field.rel.related_name: get_related_model(child_relation) - for child_relation in get_all_child_relations(type(obj)) + for child_relation in get_all_child_relations(model) } # Loop through fields for field_name in fields: # Check child relations if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'): - yield field_name, [ - dict(get_api_data(child_object, child_relations[field_name].api_fields)) - for child_object in getattr(obj, field_name).all() - ] + yield ChildRelationField(field_name, child_relations[field_name].api_fields) continue # Check django fields try: - field = obj._meta.get_field(field_name) + field = model._meta.get_field(field_name) if field.rel and isinstance(field.rel, models.ManyToOneRel): - # Foreign key - val = field._get_val_from_obj(obj) - - if val: - yield field_name, OrderedDict([ - ('id', field._get_val_from_obj(obj)), - ('meta', OrderedDict([ - ('type', field.rel.to._meta.app_label + '.' + field.rel.to.__name__), - ('detail_url', ObjectDetailURL(field.rel.to, val)), - ])), - ]) - else: - yield field_name, None + yield RelatedObjectField(field_name) + elif isinstance(field, wagtailcore_fields.StreamField): + yield StreamField(field_name) + elif isinstance(field, TaggableManager): + yield TagsField(field_name) else: - yield field_name, field._get_val_from_obj(obj) + yield AttrField(field_name) continue except models.fields.FieldDoesNotExist: pass # Check attributes - if hasattr(obj, field_name): - value = getattr(obj, field_name) - yield field_name, force_text(value, strings_only=True) + if hasattr(model, field_name): + yield AttrField(field_name) continue +def get_api_data(obj, fields): + serializer_fields = get_serializer_fields(type(obj), fields) + + for field in serializer_fields: + value = field.get_attribute(obj) + + if value is not None: + yield field.field_name, field.to_representation(value) + else: + yield field.field_name, None + + class WagtailSerializer(BaseSerializer): def to_representation(self, instance): request = self.context['request']