From 8099384391f206175013e35e0a201cb02792946d Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Thu, 31 Aug 2023 22:16:25 -0700 Subject: [PATCH] Add foreign key from embeddings to collections, refs #185 --- llm/embeddings_migrations.py | 5 +++++ tests/test_migrate.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/llm/embeddings_migrations.py b/llm/embeddings_migrations.py index 0cc9f71..7e1c590 100644 --- a/llm/embeddings_migrations.py +++ b/llm/embeddings_migrations.py @@ -17,3 +17,8 @@ def m001_create_tables(db): }, pk=("collection_id", "id"), ) + + +@embeddings_migrations() +def m002_foreign_key(db): + db["embeddings"].add_foreign_key("collection_id", "collections", "id") diff --git a/tests/test_migrate.py b/tests/test_migrate.py index a79b137..89deaae 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,4 +1,5 @@ from llm.migrations import migrate +from llm.embeddings_migrations import embeddings_migrations import pytest import sqlite_utils @@ -78,3 +79,18 @@ def test_migrations_with_legacy_alter_table(): db = sqlite_utils.Database(memory=True) db.execute("pragma legacy_alter_table=on") migrate(db) + + +def test_migrations_for_embeddings(): + db = sqlite_utils.Database(memory=True) + embeddings_migrations.apply(db) + assert db["collections"].columns_dict == {"id": int, "name": str, "model": str} + assert db["embeddings"].columns_dict == { + "collection_id": int, + "id": str, + "embedding": bytes, + "content": str, + "metadata": str, + } + assert db["embeddings"].foreign_keys[0].column == "collection_id" + assert db["embeddings"].foreign_keys[0].other_table == "collections"