diff --git a/categories/templatetags/category_tags.py b/categories/templatetags/category_tags.py index 3a8c123..f5ff719 100644 --- a/categories/templatetags/category_tags.py +++ b/categories/templatetags/category_tags.py @@ -1,46 +1,66 @@ from django import template from categories.models import Category from mptt.utils import drilldown_tree_for_node +from mptt.templatetags.mptt_tags import tree_path, tree_info -class CategoryNode(template.Node): +register = template.Library() + +register.filter("category_path", tree_path) +register.filter(tree_info) + +def get_category(category_string): + """ + Convert a string, including a path, and return the Category object + """ + if category.startswith('"') and category.endswith('"'): + category = category_string[1:-1] + else: + category = category_string + + if category.startswith('/'): + category = category[1:] + if category.endswith('/'): + category = category[:-1] + + cat_list = category.split('/') + if len(cat_list) == 0: + return None + try: + categories = Category.objects.filter(name = cat_list[-1], level=len(cat_list)-1) + if len(cat_list) == 1 and len(categories) > 1: + return None + # If there is only one, use it. If there is more than one, check + # if the parent matches the parent passed in the string + if len(categories) == 1: + return categories[0] + else: + for item in categories: + if item.parent.name == cat_list[-2]: + return item + except Category.DoesNotExist: + return None + + +class CategoryDrillDownNode(template.Node): def __init__(self, category, varname): - if category.startswith('"') and category.endswith('"'): - self.category = category[1:-1] - else: - self.category = category - - if self.category.startswith('/'): - self.category = self.category[1:] - else: - self.category = category - if self.category.endswith('/'): - self.category = self.category[:-1] - + self.category = get_category(category) self.varname = varname def render(self, context): try: - cat_list = self.category.split('/') - categories = Category.objects.filter(name = cat_list[-1], level=len(cat_list)-1) - - if len(cat_list) == 1 and len(categories) > 1: - context[self.varname] = None - return '' - if len(categories) == 1: - context[self.varname] = drilldown_tree_for_node(categories[0]) + if self.category is not None: + context[self.varname] = drilldown_tree_for_node(self.category) else: - for item in categories: - if item.parent.name == cat_list[-2]: - context[self.varname] = drilldown_tree_for_node(item) + context[self.varname] = None except Category.DoesNotExist: context[self.varname] = None return '' -def get_category(parser, token): +def get_category_drilldown(parser, token): """ - Retrieves the specified category tree. + Retrieves the specified category, it's ancestors and its children as an iterable list. Syntax:: @@ -48,8 +68,12 @@ def get_category(parser, token): Example:: - {% get_category "/Rock" as rock %} - {% get_category "/Rock/C-Rock" as crock %} + {% get_category "/Grandparent/Parent" as family %} + + Returns an iterable with:: + + Grandparent, Parent, Child 1, Child 2, Child n + """ bits = token.contents.split() error_str = '%(tagname)s tag should be in the format {%% %(tagname)s ' \ @@ -58,7 +82,24 @@ def get_category(parser, token): raise template.TemplateSyntaxError, error_str % {'tagname': bits[0]} category = bits[1] varname = bits[3] - return CategoryNode(category, varname) + return CategoryDrillDownNode(category, varname) -register = template.Library() -register.tag(get_category) +register.tag(get_category_drilldown) + +@register.inclusion_tag('categories/ancestors_ul.html') +def display_path_as_ul(category): + """ + Display the category with ancestors, but no children. + + Example:: + + {% display_path_as_ul "/Grandparent/Parent" %} + + Returns:: + +