diff --git a/auditlog/management/commands/auditlogflush.py b/auditlog/management/commands/auditlogflush.py index e57ef2c..9789709 100644 --- a/auditlog/management/commands/auditlogflush.py +++ b/auditlog/management/commands/auditlogflush.py @@ -1,6 +1,7 @@ import datetime from django.core.management.base import BaseCommand +from django.db import connection from auditlog.models import LogEntry @@ -25,11 +26,24 @@ class Command(BaseCommand): dest="before_date", type=datetime.date.fromisoformat, ) + parser.add_argument( + "-t", + "--truncate", + action="store_true", + default=None, + help="Truncate log entry table.", + dest="truncate", + ) def handle(self, *args, **options): answer = options["yes"] + truncate = options["truncate"] before = options["before_date"] - + if truncate and before: + self.stdout.write( + "Truncate deletes all log entries and can not be passed with before-date." + ) + return if answer is None: warning_message = ( "This action will clear all log entries from the database." @@ -42,11 +56,45 @@ class Command(BaseCommand): ) answer = response == "y" - if answer: + if not answer: + self.stdout.write("Aborted.") + return + + if not truncate: entries = LogEntry.objects.all() if before is not None: entries = entries.filter(timestamp__date__lt=before) count, _ = entries.delete() self.stdout.write("Deleted %d objects." % count) + return + else: - self.stdout.write("Aborted.") + database_vendor = connection.vendor + database_display_name = connection.display_name + table_name = LogEntry._meta.db_table + truncate_query = TruncateQuery(database_vendor, table_name) + if truncate_query.is_not_supported: + self.stdout.write( + "Database %s does not support truncate statement." + % database_display_name + ) + return + with connection.cursor() as cursor: + query = truncate_query.to_sql() + cursor.execute(query) + self.stdout.write("Truncated log entry table.") + + +class TruncateQuery: + SUPPORTED_VENDORS = ("postgresql", "mysql", "sqlite", "oracle", "microsoft") + + def __init__(self, database_vendor: str, table_name: str): + self.database_vendor = database_vendor + self.table_name = table_name + + @property + def is_not_supported(self): + return self.database_vendor not in self.SUPPORTED_VENDORS + + def to_sql(self): + return f"TRUNCATE TABLE {self.table_name};"