diff --git a/categories/templatetags/category_tags.py b/categories/templatetags/category_tags.py
index 88ed1f1..1331cda 100644
--- a/categories/templatetags/category_tags.py
+++ b/categories/templatetags/category_tags.py
@@ -1,18 +1,17 @@
from django import template
from django.db.models import get_model
-from django.template import (Node, TemplateSyntaxError, Variable,
- VariableDoesNotExist)
+from django.template import (Node, TemplateSyntaxError, VariableDoesNotExist,
+ FilterExpression)
from categories.base import CategoryBase
from categories.models import Category
from mptt.utils import drilldown_tree_for_node
-from mptt.templatetags.mptt_tags import (tree_path, tree_info, recursetree,
+from mptt.templatetags.mptt_tags import (tree_path, tree_info, RecurseTreeNode,
full_tree_for_model)
register = template.Library()
register.filter("category_path", tree_path)
register.filter(tree_info)
-register.tag(recursetree)
register.tag("full_tree_for_category", full_tree_for_model)
@@ -72,7 +71,7 @@ def get_category(category_string, model=Category):
class CategoryDrillDownNode(template.Node):
def __init__(self, category, varname, model):
- self.category = template.Variable(category)
+ self.category = category
self.varname = varname
self.model = model
@@ -133,7 +132,7 @@ def get_category_drilldown(parser, token):
if bits[2] == 'using':
varname = bits[5].strip("'\"")
model = bits[3].strip("'\"")
- category = bits[1]
+ category = FilterExpression(bits[1], parser)
return CategoryDrillDownNode(category, varname, model)
@@ -282,12 +281,12 @@ class LatestObjectsNode(Node):
"""
Get latest objects of app_label.model_name
"""
- self.category = Variable(category)
- self.app_label = Variable(app_label)
- self.model_name = Variable(model_name)
- self.set_name = Variable(set_name)
- self.date_field = Variable(date_field)
- self.num = Variable(num)
+ self.category = category
+ self.app_label = app_label
+ self.model_name = model_name
+ self.set_name = set_name
+ self.date_field = date_field
+ self.num = num
self.var_name = var_name
def render(self, context):
@@ -323,19 +322,19 @@ def do_get_latest_objects_by_category(parser, token):
raise TemplateSyntaxError("%s tag shoud be in the form: %s" % (bits[0], proper_form))
if len(bits) > 9:
raise TemplateSyntaxError("%s tag shoud be in the form: %s" % (bits[0], proper_form))
- category = bits[1]
- app_label = bits[2]
- model_name = bits[3]
- set_name = bits[4]
+ category = FilterExpression(bits[1], parser)
+ app_label = FilterExpression(bits[2], parser)
+ model_name = FilterExpression(bits[3], parser)
+ set_name = FilterExpression(bits[4], parser)
var_name = bits[-1]
if bits[5] != 'as':
- date_field = bits[5]
+ date_field = FilterExpression(bits[5], parser)
else:
- date_field = None
+ date_field = FilterExpression(None, parser)
if bits[6] != 'as':
- num = bits[6]
+ num = FilterExpression(bits[6], parser)
else:
- num = None
+ num = FilterExpression(None, parser)
return LatestObjectsNode(var_name, category, app_label, model_name, set_name,
date_field, num)
@@ -376,3 +375,35 @@ def tree_queryset(value):
qs = qs.distinct()
return qs
+
+
+@register.tag
+def recursetree(parser, token):
+ """
+ Iterates over the nodes in the tree, and renders the contained block for each node.
+ This tag will recursively render children into the template variable {{ children }}.
+ Only one database query is required (children are cached for the whole tree)
+
+ Usage:
+
+ {% recursetree nodes %}
+ -
+ {{ node.name }}
+ {% if not node.is_leaf_node %}
+
+ {% endif %}
+
+ {% endrecursetree %}
+
+ """
+ bits = token.contents.split()
+ if len(bits) != 2:
+ raise template.TemplateSyntaxError('%s tag requires a queryset' % bits[0])
+ queryset_var = FilterExpression(bits[1], parser)
+
+ template_nodes = parser.parse(('endrecursetree',))
+ parser.delete_first_token()
+
+ return RecurseTreeNode(template_nodes, queryset_var)
diff --git a/categories/tests/category_import.py b/categories/tests/category_import.py
index b2886e1..6923747 100644
--- a/categories/tests/category_import.py
+++ b/categories/tests/category_import.py
@@ -1,64 +1,62 @@
# test spaces in hierarchy
# test tabs in hierarchy
# test mixed
-import unittest, os
+import unittest
+import os
from categories.models import Category
from categories.management.commands.import_categories import Command
from django.core.management.base import CommandError
+
class CategoryImportTest(unittest.TestCase):
def setUp(self):
pass
-
+
def _import_file(self, filename):
root_cats = ['Category 1', 'Category 2', 'Category 3']
testfile = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'fixtures', filename))
cmd = Command()
cmd.execute(testfile)
roots = Category.tree.root_nodes()
-
+
self.assertEqual(len(roots), 3)
for item in roots:
assert item.name in root_cats
-
+
cat2 = Category.objects.get(name='Category 2')
cat21 = cat2.children.all()[0]
self.assertEqual(cat21.name, 'Category 2-1')
cat211 = cat21.children.all()[0]
self.assertEqual(cat211.name, 'Category 2-1-1')
-
-
+
def testImportSpaceDelimited(self):
Category.objects.all().delete()
self._import_file('test_category_spaces.txt')
-
+
items = Category.objects.all()
-
+
self.assertEqual(items[0].name, 'Category 1')
self.assertEqual(items[1].name, 'Category 1-1')
self.assertEqual(items[2].name, 'Category 1-2')
-
-
+
def testImportTabDelimited(self):
Category.objects.all().delete()
self._import_file('test_category_tabs.txt')
-
+
items = Category.objects.all()
-
+
self.assertEqual(items[0].name, 'Category 1')
self.assertEqual(items[1].name, 'Category 1-1')
self.assertEqual(items[2].name, 'Category 1-2')
-
-
+
def testMixingTabsSpaces(self):
"""
Should raise an exception.
"""
- string1 = ["cat1"," cat1-1", "\tcat1-2-FAIL!",""]
- string2 = ["cat1","\tcat1-1"," cat1-2-FAIL!",""]
+ string1 = ["cat1", " cat1-1", "\tcat1-2-FAIL!", ""]
+ string2 = ["cat1", "\tcat1-1", " cat1-2-FAIL!", ""]
cmd = Command()
-
+
# raise Exception
self.assertRaises(CommandError, cmd.parse_lines, string1)
self.assertRaises(CommandError, cmd.parse_lines, string2)
-
\ No newline at end of file
diff --git a/categories/tests/manager.py b/categories/tests/manager.py
index f8c0524..2d46a93 100644
--- a/categories/tests/manager.py
+++ b/categories/tests/manager.py
@@ -1,23 +1,22 @@
# test active returns only active items
-import unittest, os
+import unittest
from categories.models import Category
-from categories.management.commands.import_categories import Command
-from django.core.management.base import CommandError
+
class CategoryManagerTest(unittest.TestCase):
def setUp(self):
pass
-
+
def testActive(self):
"""
Should raise an exception.
"""
all_count = Category.objects.all().count()
self.assertEqual(Category.objects.active().count(), all_count)
-
+
cat1 = Category.objects.get(name='Category 1')
cat1.active = False
cat1.save()
-
+
active_count = all_count - cat1.get_descendants(True).count()
self.assertEqual(Category.objects.active().count(), active_count)
diff --git a/categories/tests/templatetags.py b/categories/tests/templatetags.py
index bf6bf71..b70488a 100644
--- a/categories/tests/templatetags.py
+++ b/categories/tests/templatetags.py
@@ -58,11 +58,10 @@ class CategoryTagsTest(TestCase):
self.assertEqual(resp, expected_resp)
# recursetree
- expected_resp = u''
- ctxt = {'nodes': Category.objects.filter(name__in=("Worldbeat", "Urban cowboy"))}
+ expected_resp = u''
+ ctxt = {'nodes': Category.objects.filter(name__in=("Worldbeat", "Urban Cowboy"))}
resp = self.render_template('{% load category_tags %}'
- '{% recursetree nodes|treeinfo %}- {{ node.name }}'
- '{% if not node.is_leaf_node %}
{{ children }}'
+ '{% recursetree nodes|tree_queryset %}- {{ node.name }}'
+ '{% if not node.is_leaf_node %}{% endif %}
{% endrecursetree %}
', ctxt)
self.assertEqual(resp, expected_resp)
-