diff --git a/wagtail/api/v2/endpoints.py b/wagtail/api/v2/endpoints.py index 73057e1f1..ba81ba2f8 100644 --- a/wagtail/api/v2/endpoints.py +++ b/wagtail/api/v2/endpoints.py @@ -94,57 +94,32 @@ class BaseAPIEndpoint(GenericViewSet): return Response(data, status=status.HTTP_400_BAD_REQUEST) return super(BaseAPIEndpoint, self).handle_exception(exc) + @classmethod + def _convert_api_fields(cls, fields): + return [field if isinstance(field, APIField) else APIField(field) + for field in fields] + @classmethod def get_body_fields(cls, model): - """ - This returns a list of field names that are allowed to - be used in the API (excluding the id field) - """ - fields = cls.body_fields[:] + return cls._convert_api_fields(cls.body_fields + list(getattr(model, 'api_fields', ()))) - if hasattr(model, 'api_fields'): - fields.extend([ - field.name if isinstance(field, APIField) else field - for field in model.api_fields - ]) - - return fields + @classmethod + def get_body_fields_names(cls, model): + return [field.name for field in cls.get_body_fields(model)] @classmethod def get_meta_fields(cls, model): - """ - This returns a list of field names that are allowed to - be used in the meta section in the API (excluding type and detail_url). - """ - meta_fields = cls.meta_fields[:] + return cls._convert_api_fields(cls.meta_fields + list(getattr(model, 'api_meta_fields', ()))) - if hasattr(model, 'api_meta_fields'): - meta_fields.extend([ - field.name if isinstance(field, APIField) else field - for field in model.api_meta_fields - ]) - - return meta_fields + @classmethod + def get_meta_fields_names(cls, model): + return [field.name for field in cls.get_meta_fields(model)] @classmethod def get_field_serializer_overrides(cls, model): - serializers = {} - - if hasattr(model, 'api_fields'): - serializers.update({ - field.name: field.serializer - for field in model.api_fields - if isinstance(field, APIField) and field.serializer is not None - }) - - if hasattr(model, 'api_meta_fields'): - serializers.update({ - field.name: field.serializer - for field in model.api_meta_fields - if isinstance(field, APIField) and field.serializer is not None - }) - - return serializers + return {field.name: field.serializer + for field in cls.get_body_fields(model) + cls.get_meta_fields(model) + if field.serializer is not None} @classmethod def get_available_fields(cls, model, db_fields_only=False): @@ -156,7 +131,7 @@ class BaseAPIEndpoint(GenericViewSet): an underlying column in the database (eg, type/detail_url and any custom fields that are callables) """ - fields = cls.get_body_fields(model) + cls.get_meta_fields(model) + fields = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model) if db_fields_only: # Get list of available database fields then remove any fields in our @@ -199,8 +174,8 @@ class BaseAPIEndpoint(GenericViewSet): @classmethod def _get_serializer_class(cls, router, model, fields_config, show_details=False, nested=False): # Get all available fields - body_fields = cls.get_body_fields(model) - meta_fields = cls.get_meta_fields(model) + body_fields = cls.get_body_fields_names(model) + meta_fields = cls.get_meta_fields_names(model) all_fields = body_fields + meta_fields # Remove any duplicates