Collection now defaults to in-memory DB, closes #213

This commit is contained in:
Simon Willison 2023-09-02 15:43:10 -07:00
parent 8bdaca1f31
commit 51488c579b
5 changed files with 16 additions and 13 deletions

View file

@ -30,13 +30,17 @@ To work with embeddings in this way you will need an instance of a [sqlite-utils
import sqlite_utils
import llm
db = sqlite_utils.Database("my-embeddings.db")
# Pass model_id= to specify a model for the collection
collection = llm.Collection(db, "entries", model_id="ada-002")
# This collection will use an in-memory database that will be
# discarded when the Python process exits
collection = llm.Collection("entries", model_id="ada-002")
# Or you can pass a model directly using model=
# Or you can persist the database to disk like this:
db = sqlite_utils.Database("my-embeddings.db")
collection = llm.Collection("entries", db, model_id="ada-002")
# You can pass a model directly using model= instead of model_id=
embedding_model = llm.get_embedding_model("ada-002")
collection = llm.Collection(db, "entries", model=embedding_model)
collection = llm.Collection("entries", db, model=embedding_model)
```
If the collection already exists in the database you can omit the `model` or `model_id` argument - the model ID will be read from the `collections` table.

View file

@ -910,7 +910,7 @@ def embed(collection, id, input, model, store, database, content, format_):
model_obj = None
if collection:
db = get_db()
collection_obj = Collection(db, collection, model_id=model)
collection_obj = Collection(collection, db, model_id=model)
model_obj = collection_obj.model()
if model_obj is None:
@ -995,7 +995,7 @@ def similar(collection, id, input, content, number, database):
raise click.ClickException("No embeddings table found in database")
try:
collection_obj = Collection(db, collection, create=False)
collection_obj = Collection(collection, db, create=False)
except Collection.DoesNotExist:
raise click.ClickException("Collection does not exist")

View file

@ -24,8 +24,8 @@ class Collection:
def __init__(
self,
db: Database,
name: str,
db: Optional[Database] = None,
*,
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
@ -48,7 +48,7 @@ class Collection:
"""
import llm
self.db = db
self.db = db or Database(memory=True)
self.name = name
self._model = model

View file

@ -21,7 +21,7 @@ def user_path(tmpdir):
def user_path_with_embeddings(user_path):
path = str(user_path / "embeddings.db")
db = sqlite_utils.Database(path)
collection = llm.Collection(db, "demo", model_id="embed-demo")
collection = llm.Collection("demo", db, model_id="embed-demo")
collection.embed("1", "hello world")
collection.embed("2", "goodbye world")

View file

@ -7,8 +7,7 @@ import pytest
@pytest.fixture
def collection():
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
collection = llm.Collection("test", model_id="embed-demo")
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
return collection
@ -95,7 +94,7 @@ def test_similar_by_id(collection):
@pytest.mark.parametrize("with_metadata", (False, True))
def test_embed_multi(with_metadata):
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
collection = llm.Collection("test", db, model_id="embed-demo")
ids_and_texts = ((str(i), "hello {}".format(i)) for i in range(1000))
if with_metadata:
ids_and_texts = ((id, text, {"meta": id}) for id, text in ids_and_texts)