diff --git a/categories/management/commands/import_categories.py b/categories/management/commands/import_categories.py index 76dc605..8ff419c 100644 --- a/categories/management/commands/import_categories.py +++ b/categories/management/commands/import_categories.py @@ -1,6 +1,7 @@ from django.core.management.base import BaseCommand, CommandError from categories.models import Category from django.template.defaultfilters import slugify +from django.db import transaction class Command(BaseCommand): """Import category trees from a file.""" @@ -22,17 +23,23 @@ class Command(BaseCommand): else: return ' ' * indent_amt - + @transaction.commit_on_success def make_category(self, string, parent=None, order=1): """ Make and save a category object from a string """ - return Category.objects.create( + cat = Category( name=string.strip(), slug=slugify(string.strip())[:49], - parent=parent, + #parent=parent, order=order ) + cat._tree_manager.insert_node(cat, parent, 'last-child', True) + cat.save() + if parent: + parent.rght = cat.rght + 1 + parent.save() + return cat def parse_lines(self, lines): """ @@ -60,6 +67,7 @@ class Command(BaseCommand): else: # We are back to a zero level, so reset the whole thing current_parents = {0: self.make_category(line)} + current_parents[0]._tree_manager.rebuild() def handle(self, *file_paths, **options): """ diff --git a/categories/tests/__init__.py b/categories/tests/__init__.py index 6934d61..b144d87 100644 --- a/categories/tests/__init__.py +++ b/categories/tests/__init__.py @@ -1,4 +1,5 @@ from categories.tests.category_import import * from categories.tests.templatetags import * +from categories.tests.manager import * __fixtures__ = ['categories.json'] diff --git a/categories/tests/category_import.py b/categories/tests/category_import.py index 94801ce..b2886e1 100644 --- a/categories/tests/category_import.py +++ b/categories/tests/category_import.py @@ -31,11 +31,23 @@ class CategoryImportTest(unittest.TestCase): def testImportSpaceDelimited(self): Category.objects.all().delete() self._import_file('test_category_spaces.txt') + + items = Category.objects.all() + + self.assertEqual(items[0].name, 'Category 1') + self.assertEqual(items[1].name, 'Category 1-1') + self.assertEqual(items[2].name, 'Category 1-2') def testImportTabDelimited(self): Category.objects.all().delete() self._import_file('test_category_tabs.txt') + + items = Category.objects.all() + + self.assertEqual(items[0].name, 'Category 1') + self.assertEqual(items[1].name, 'Category 1-1') + self.assertEqual(items[2].name, 'Category 1-2') def testMixingTabsSpaces(self):