Refactored get_api_data into classes

This commit is contained in:
Karl Hobley 2015-08-03 11:54:45 +01:00
parent 30408252ff
commit 4523f8b5cc
2 changed files with 80 additions and 37 deletions

View file

@ -6,11 +6,6 @@ from django.utils.six import text_type
from rest_framework import renderers
from taggit.managers import _TaggableManager
from taggit.models import Tag
from wagtail.wagtailcore.blocks import StreamValue
from .utils import URLPath, ObjectDetailURL, get_base_url
@ -35,11 +30,7 @@ class WagtailJSONRenderer(renderers.BaseRenderer):
class WagtailAPIJSONEncoder(DjangoJSONEncoder):
def default(self, o):
if isinstance(o, _TaggableManager):
return list(o.all())
elif isinstance(o, Tag):
return o.name
elif isinstance(o, URLPath):
if isinstance(o, URLPath):
return get_full_url(request, o.path)
elif isinstance(o, ObjectDetailURL):
detail_view = find_model_detail_view(o.model, endpoints)
@ -48,8 +39,6 @@ class WagtailJSONRenderer(renderers.BaseRenderer):
return get_full_url(request, reverse(detail_view, args=(o.pk, )))
else:
return None
elif isinstance(o, StreamValue):
return o.stream_block.get_prep_value(o)
else:
return super(WagtailAPIJSONEncoder, self).default(o)

View file

@ -7,65 +7,119 @@ from django.utils.encoding import force_text
from modelcluster.models import get_all_child_relations
from taggit.managers import TaggableManager
from rest_framework.serializers import BaseSerializer
from wagtail.utils.compat import get_related_model
from wagtail.wagtailcore.models import Page
from wagtail.wagtailcore import fields as wagtailcore_fields
from .utils import ObjectDetailURL, URLPath, BadRequestError, pages_for_site
def get_api_data(obj, fields):
class Field(object):
def __init__(self, field_name):
self.field_name = field_name
def get_attribute(self, instance):
return getattr(instance, self.field_name)
def to_representation(self, value):
return value
class AttrField(Field):
def to_representation(self, value):
return force_text(value, strings_only=True)
class ChildRelationField(Field):
def __init__(self, field_name, fields):
self.field_name = field_name
self.fields = fields
def to_representation(self, value):
return [
dict(get_api_data(child_object, self.fields))
for child_object in value.all()
]
class RelatedObjectField(Field):
def to_representation(self, value):
model = type(value)
return OrderedDict([
('id', value.pk),
('meta', OrderedDict([
('type', model._meta.app_label + '.' + model.__name__),
('detail_url', ObjectDetailURL(model, value.pk)),
])),
])
class StreamField(Field):
def to_representation(self, value):
return value.stream_block.get_prep_value(value)
class TagsField(Field):
def to_representation(self, value):
return list(value.all().values_list('name', flat=True))
def get_serializer_fields(model, fields):
# Find any child relations (pages only)
child_relations = {}
if isinstance(obj, Page):
if issubclass(model, Page):
child_relations = {
child_relation.field.rel.related_name: get_related_model(child_relation)
for child_relation in get_all_child_relations(type(obj))
for child_relation in get_all_child_relations(model)
}
# Loop through fields
for field_name in fields:
# Check child relations
if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'):
yield field_name, [
dict(get_api_data(child_object, child_relations[field_name].api_fields))
for child_object in getattr(obj, field_name).all()
]
yield ChildRelationField(field_name, child_relations[field_name].api_fields)
continue
# Check django fields
try:
field = obj._meta.get_field(field_name)
field = model._meta.get_field(field_name)
if field.rel and isinstance(field.rel, models.ManyToOneRel):
# Foreign key
val = field._get_val_from_obj(obj)
if val:
yield field_name, OrderedDict([
('id', field._get_val_from_obj(obj)),
('meta', OrderedDict([
('type', field.rel.to._meta.app_label + '.' + field.rel.to.__name__),
('detail_url', ObjectDetailURL(field.rel.to, val)),
])),
])
else:
yield field_name, None
yield RelatedObjectField(field_name)
elif isinstance(field, wagtailcore_fields.StreamField):
yield StreamField(field_name)
elif isinstance(field, TaggableManager):
yield TagsField(field_name)
else:
yield field_name, field._get_val_from_obj(obj)
yield AttrField(field_name)
continue
except models.fields.FieldDoesNotExist:
pass
# Check attributes
if hasattr(obj, field_name):
value = getattr(obj, field_name)
yield field_name, force_text(value, strings_only=True)
if hasattr(model, field_name):
yield AttrField(field_name)
continue
def get_api_data(obj, fields):
serializer_fields = get_serializer_fields(type(obj), fields)
for field in serializer_fields:
value = field.get_attribute(obj)
if value is not None:
yield field.field_name, field.to_representation(value)
else:
yield field.field_name, None
class WagtailSerializer(BaseSerializer):
def to_representation(self, instance):
request = self.context['request']