From 0b77112bf3a09fcda604d2f386f8922944f46632 Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Mon, 3 Aug 2015 17:09:45 +0100 Subject: [PATCH] Moved get_serializer_field logic into base serializer class --- wagtail/contrib/wagtailapi/serializers.py | 83 +++++++++++------------ 1 file changed, 40 insertions(+), 43 deletions(-) diff --git a/wagtail/contrib/wagtailapi/serializers.py b/wagtail/contrib/wagtailapi/serializers.py index 0150de668..62f6ba197 100644 --- a/wagtail/contrib/wagtailapi/serializers.py +++ b/wagtail/contrib/wagtailapi/serializers.py @@ -2,14 +2,13 @@ from __future__ import absolute_import from collections import OrderedDict -from django.db import models - from modelcluster.models import get_all_child_relations -from taggit.managers import TaggableManager +from taggit.managers import _TaggableManager -from rest_framework.serializers import BaseSerializer, Serializer -from rest_framework.fields import Field, ReadOnlyField +from rest_framework import serializers +from rest_framework.fields import Field +from rest_framework.relations import RelatedField from wagtail.utils.compat import get_related_model from wagtail.wagtailcore.models import Page @@ -33,7 +32,7 @@ class ChildRelationField(Field): ] -class RelatedObjectField(Field): +class RelatedObjectField(RelatedField): def to_representation(self, value): model = type(value) @@ -56,53 +55,51 @@ class TagsField(Field): return list(value.all().values_list('name', flat=True)) -def get_serializer_field(model, field_name): - # Find any child relations (pages only) - child_relations = {} - if issubclass(model, Page): - child_relations = { - child_relation.field.rel.related_name: get_related_model(child_relation) - for child_relation in get_all_child_relations(model) - } +class BaseSerializer(serializers.ModelSerializer): + # Add StreamField to serializer_field_mapping + serializer_field_mapping = serializers.ModelSerializer.serializer_field_mapping.copy() + serializer_field_mapping.update({ + wagtailcore_fields.StreamField: StreamField, + }) + serializer_related_field = RelatedObjectField - # Check child relations - if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'): - return ChildRelationField, {'child_fields': child_relations[field_name].api_fields} - - # Check django fields - try: - field = model._meta.get_field(field_name) - - if field.rel and isinstance(field.rel, models.ManyToOneRel): - return RelatedObjectField, {} - elif isinstance(field, wagtailcore_fields.StreamField): - return StreamField, {} - elif isinstance(field, TaggableManager): + def build_property_field(self, field_name, model_class): + # TaggableManager is not a Django field so it gets treated as a property + field = getattr(model_class, field_name) + if isinstance(field, _TaggableManager): return TagsField, {} - else: - return ReadOnlyField, {} - except models.fields.FieldDoesNotExist: - pass + return super(BaseSerializer, self).build_property_field(field_name, model_class) - # Check attributes - if hasattr(model, field_name): - return ReadOnlyField, {} + def build_relational_field(self, field_name, relation_info): + if relation_info.to_many: + # Find child relations (pages only) + model = getattr(self.Meta, 'model') + child_relations = {} + if issubclass(model, Page): + child_relations = { + child_relation.field.rel.related_name: get_related_model(child_relation) + for child_relation in get_all_child_relations(model) + } + + # Check child relations + if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'): + return ChildRelationField, {'child_fields': child_relations[field_name].api_fields} + + return super(BaseSerializer, self).build_relational_field(field_name, relation_info) -def get_serializer_class(model, fields): - serializer_fields = [ - (field_name, get_serializer_field(model, field_name)) - for field_name in fields - ] +def get_serializer_class(model_, fields_): + class Meta: + model = model_ + fields = fields_ - return type(model.__name__ + 'Serializer', (Serializer, ), { - field_name: field_class(**field_kwargs) - for field_name, (field_class, field_kwargs) in serializer_fields + return type(model_.__name__ + 'Serializer', (BaseSerializer, ), { + 'Meta': Meta }) -class WagtailSerializer(BaseSerializer): +class WagtailSerializer(serializers.BaseSerializer): def to_representation(self, instance): request = self.context['request'] fields = self.context.get('fields', frozenset())