Use rest frameworks fields API

This commit is contained in:
Karl Hobley 2015-08-03 13:49:13 +01:00
parent 4523f8b5cc
commit 85a4209893

View file

@ -3,13 +3,13 @@ from __future__ import absolute_import
from collections import OrderedDict
from django.db import models
from django.utils.encoding import force_text
from modelcluster.models import get_all_child_relations
from taggit.managers import TaggableManager
from rest_framework.serializers import BaseSerializer
from rest_framework.fields import Field, ReadOnlyField
from wagtail.utils.compat import get_related_model
from wagtail.wagtailcore.models import Page
@ -18,30 +18,14 @@ from wagtail.wagtailcore import fields as wagtailcore_fields
from .utils import ObjectDetailURL, URLPath, BadRequestError, pages_for_site
class Field(object):
def __init__(self, field_name):
self.field_name = field_name
def get_attribute(self, instance):
return getattr(instance, self.field_name)
def to_representation(self, value):
return value
class AttrField(Field):
def to_representation(self, value):
return force_text(value, strings_only=True)
class ChildRelationField(Field):
def __init__(self, field_name, fields):
self.field_name = field_name
self.fields = fields
def __init__(self, *args, **kwargs):
self.child_fields = kwargs.pop('child_fields')
super(ChildRelationField, self).__init__(*args, **kwargs)
def to_representation(self, value):
return [
dict(get_api_data(child_object, self.fields))
dict(get_api_data(child_object, self.child_fields))
for child_object in value.all()
]
@ -82,7 +66,7 @@ def get_serializer_fields(model, fields):
for field_name in fields:
# Check child relations
if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'):
yield ChildRelationField(field_name, child_relations[field_name].api_fields)
yield field_name, ChildRelationField, {'child_fields': child_relations[field_name].api_fields}
continue
# Check django fields
@ -90,13 +74,13 @@ def get_serializer_fields(model, fields):
field = model._meta.get_field(field_name)
if field.rel and isinstance(field.rel, models.ManyToOneRel):
yield RelatedObjectField(field_name)
yield field_name, RelatedObjectField, {}
elif isinstance(field, wagtailcore_fields.StreamField):
yield StreamField(field_name)
yield field_name, StreamField, {}
elif isinstance(field, TaggableManager):
yield TagsField(field_name)
yield field_name, TagsField, {}
else:
yield AttrField(field_name)
yield field_name, ReadOnlyField, {}
continue
except models.fields.FieldDoesNotExist:
@ -104,14 +88,17 @@ def get_serializer_fields(model, fields):
# Check attributes
if hasattr(model, field_name):
yield AttrField(field_name)
yield field_name, ReadOnlyField, {}
continue
def get_api_data(obj, fields):
serializer_fields = get_serializer_fields(type(obj), fields)
for field in serializer_fields:
for field_name, field_class, field_kwargs in serializer_fields:
field = field_class(**field_kwargs)
field.bind(field_name, None)
value = field.get_attribute(obj)
if value is not None: