diff --git a/AUTHORS.rst b/AUTHORS.rst index b601622..c903695 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -106,3 +106,4 @@ | Őry Máté | Nafees Anwar | meanmail +| Nicholas Prat diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 6109380..11ab314 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -263,7 +263,7 @@ class FieldInstanceTracker: if deferred_fields: fields = [ field for field in self.fields - if field not in deferred_fields + if self.field_map[field] not in deferred_fields ] else: fields = self.fields diff --git a/tests/models.py b/tests/models.py index 4d34505..8c6b046 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, ClassVar, TypeVar, overload +from typing import Any, ClassVar, Iterable, TypeVar, overload from django.db import models from django.db.models import Manager @@ -24,7 +24,7 @@ from model_utils.models import ( TimeStampedModel, UUIDModel, ) -from model_utils.tracker import FieldTracker, ModelTracker +from model_utils.tracker import FieldInstanceTracker, FieldTracker, ModelTracker from tests.fields import MutableField ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) @@ -280,6 +280,29 @@ class TrackedMultiple(models.Model): number_tracker = FieldTracker(fields=['number']) +class LoopDetectionFieldInstanceTracker(FieldInstanceTracker): + + def set_saved_fields(self, fields: Iterable[str] | None = None) -> None: + counter = getattr(self.__class__, '__loop_counter', 0) + if counter > 50: + raise AssertionError("Infinite Loop Detected!") + setattr(self.__class__, '__loop_counter', counter + 1) + super().set_saved_fields(fields) + + +class LoopDetectionFieldTracker(FieldTracker): + tracker_class = LoopDetectionFieldInstanceTracker + + +class TrackedProtectedSelfRefFK(models.Model): + fk = models.ForeignKey('Tracked', on_delete=models.PROTECT) + self_ref = models.ForeignKey('self', on_delete=models.SET_NULL, null=True, blank=True) + + tracker = LoopDetectionFieldTracker() + custom_tracker = LoopDetectionFieldTracker(fields=['fk_id', 'self_ref_id']) + custom_tracker_without_id = LoopDetectionFieldTracker(fields=['fk', 'self_ref']) + + class TrackedFileField(models.Model): some_file = models.FileField(upload_to='test_location') diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 81db1ec..f2f2127 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -2,9 +2,11 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any +import pytest from django.core.cache import cache from django.core.exceptions import FieldError from django.db import models +from django.db.models.deletion import ProtectedError from django.db.models.fields.files import FieldFile from django.test import TestCase @@ -25,6 +27,7 @@ from tests.models import ( TrackedMultiple, TrackedNonFieldAttr, TrackedNotDefault, + TrackedProtectedSelfRefFK, TrackerTimeStamped, ) @@ -570,6 +573,26 @@ class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): tracked_class = TrackedFK +class FieldTrackerProtectedForeignKeyTests(FieldTrackerMixin, TestCase): + """test case for issue #533 FieldTracker infinite recursion on a deleting object""" + + fk_class = Tracked + tracked_class = TrackedProtectedSelfRefFK + + def setUp(self) -> None: + self.old_fk = self.fk_class.objects.create(number=8) + self.instance = self.tracked_class.objects.create(fk=self.old_fk) + self.instance_2 = self.tracked_class.objects.create( + fk=self.old_fk, self_ref=self.instance + ) + self.instance.self_ref = self.instance_2 + self.instance.save() + + def test_fk_delete(self) -> None: + with pytest.raises(ProtectedError): + self.old_fk.delete() + + class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase): """Test that using `prefetch_related` on a tracked field does not raise a ValueError."""