"""Adds and removes category relations on the database."""

from django.apps import apps
from django.db import DatabaseError, connection, transaction
from django.db.utils import OperationalError, ProgrammingError


def table_exists(table_name):
    """
    Check if a table exists in the database.

    Returns True when ``table_name`` is among the tables reported by the
    default connection's introspection, False otherwise.
    """
    introspection = connection.introspection
    # Compare through the backend's identifier converter so that backends
    # which normalise identifier case still match.
    return introspection.identifier_converter(table_name) in introspection.table_names()


def field_exists(app_name, model_name, field_name):
    """
    Does the FK or M2M table exist in the database already?

    Returns True when the many-to-many table of the field can be described,
    or when the model's table has a column matching either the field name or
    the field's database column (e.g. ``category_id`` for a ``category``
    foreign key). Returns False otherwise.

    Errors raised while describing the model's own table are not caught here.
    """
    model = apps.get_model(app_name, model_name)
    table_name = model._meta.db_table
    introspection = connection.introspection

    # The context manager guarantees the cursor is closed on every exit path.
    with connection.cursor() as cursor:
        field_info = introspection.get_table_description(cursor, table_name)
        column_names = {column.name for column in field_info}

        # Return True if the many to many table exists
        field = model._meta.get_field(field_name)
        if hasattr(field, "m2m_db_table"):
            m2m_table_name = field.m2m_db_table()
            try:
                m2m_field_info = introspection.get_table_description(cursor, m2m_table_name)
            except DatabaseError:  # Django >= 4.1 throws DatabaseError
                m2m_field_info = []
            if m2m_field_info:
                return True

    if field_name in column_names:
        return True
    # Relational fields are stored under a column that differs from the field
    # name; ``column`` is None for fields without a column of their own.
    return getattr(field, "column", None) in column_names


def drop_field(app_name, model_name, field_name):
    """
    Drop the given field from the app's model.

    Looks up the model through the app registry and removes the field with
    the connection's schema editor.
    """
    model = apps.get_app_config(app_name).get_model(model_name)
    field = model._meta.get_field(field_name)
    with connection.schema_editor() as schema_editor:
        schema_editor.remove_field(model, field)


def migrate_app(sender, *args, **kwargs):
    """
    Migrate all models of this app registered.

    Signal-handler style entry point: does nothing unless ``app_config`` is
    passed as a keyword argument. For every registered field belonging to
    that app which does not exist in the database yet, the field is added
    with the schema editor. A failure to add one field is swallowed so the
    remaining fields are still processed.
    """
    from .registration import registry

    if "app_config" not in kwargs:
        return
    app_config = kwargs["app_config"]
    app_name = app_config.label

    # Registry keys are dotted as "<app>.<model>.<field>"; include the dot so
    # an app label that is merely a prefix of another label does not match.
    prefix = app_name + "."
    fields = [fld for fld in registry._field_registry if fld.startswith(prefix)]

    for fld in fields:
        model_name, field_name = fld.split(".")[1:]
        if field_exists(app_name, model_name, field_name):
            continue
        model = app_config.get_model(model_name)

        # One savepoint per field: a savepoint is gone once committed, so it
        # cannot be shared across iterations. ``sid`` is falsy when no
        # savepoint could be created, hence the guards below.
        sid = transaction.savepoint()
        try:
            with connection.schema_editor() as schema_editor:
                schema_editor.add_field(model, registry._field_registry[fld])
        # Django 4.1 with sqlite3 has for some reason started throwing OperationalError
        # instead of ProgrammingError, so we need to catch both.
        except (ProgrammingError, OperationalError):
            if sid:
                transaction.savepoint_rollback(sid)
            continue
        if sid:
            transaction.savepoint_commit(sid)