Collection design tweaks + llm embed/similar now use it, closes #204

This commit is contained in:
Simon Willison 2023-09-02 08:30:56 -07:00
parent 2f178a9e9b
commit 3d56d6cc24
4 changed files with 98 additions and 142 deletions

View file

@ -84,6 +84,23 @@ collection.embed_multi_with_metadata(
)
```
(embeddings-python-collection-class)=
### Collection class reference
A collection instance has the following properties and methods:
- `id` - the integer ID of the collection in the database
- `name` - the string name of the collection (unique in the database)
- `model_id` - the string ID of the embedding model used for this collection
- `model()` - returns the `EmbeddingModel` instance, based on that `model_id`
- `count()` - returns the integer number of items in the collection
- `embed(id: str, text: str, metadata: dict=None, store: bool=False)` - embeds the given string and stores the resulting embedding in the collection under the given ID. Optional `metadata` is stored as JSON, and if `store=True` the text content itself is also stored in the database table.
- `embed_multi(entries: Iterable, store: bool=False)` - see above
- `embed_multi_with_metadata(entries: Iterable, store: bool=False)` - see above
- `similar(query: str, number: int=10)` - returns a list of entries that are most similar to the embedding of the given query string
- `similar_by_id(id: str, number: int=10)` - returns a list of entries that are most similar to the embedding of the item with the given ID
- `similar_by_vector(vector: List[float], number: int=10, skip_id: str=None)` - returns a list of entries that are most similar to the given embedding vector, optionally skipping the entry with the given ID
(embeddings-python-similar)=
## Retrieving similar items

View file

@ -22,7 +22,6 @@ from llm import (
)
from .migrations import migrate
from .embeddings_migrations import embeddings_migrations
from .plugins import pm
import base64
import pathlib
@ -30,7 +29,6 @@ import pydantic
from runpy import run_module
import shutil
import sqlite_utils
from sqlite_utils.db import NotFoundError
import sys
import textwrap
from typing import cast, Optional
@ -901,8 +899,6 @@ def embed(collection, id, input, model, store, database, content, format_):
if store and not collection:
raise click.ClickException("Must provide collection when using --store")
db = None
# Lazy load this because we do not need it for -c or -i versions
def get_db():
if database:
@ -910,40 +906,20 @@ def embed(collection, id, input, model, store, database, content, format_):
else:
return sqlite_utils.Database(user_dir() / "embeddings.db")
existing_collection = None
collection_obj = None
model_obj = None
if collection:
db = get_db()
if db["collections"].exists():
try:
existing_collection = get_collection(db, collection)
except NotFoundError:
pass
collection_obj = Collection(db, collection, model_id=model)
model_obj = collection_obj.model()
if model is None:
# If collection exists, use that model
if existing_collection:
model = existing_collection["model"]
else:
# Use default model
if model_obj is None:
if not model:
model = get_default_embedding_model()
if model and existing_collection:
# Resolve aliases before comparison
model_resolved = get_embedding_model(model).model_id
collection_model_resolved = get_embedding_model(
existing_collection["model"]
).model_id
if model_resolved != collection_model_resolved:
raise click.ClickException(
"Model '{}' does not match '{}' collection model of '{}'".format(
model, collection, existing_collection["model"]
)
)
try:
model = get_embedding_model(model)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
try:
model_obj = get_embedding_model(model)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
show_output = True
if collection and (format_ is None):
@ -958,34 +934,10 @@ def embed(collection, id, input, model, store, database, content, format_):
if not content:
raise click.ClickException("No content provided")
embedding = model.embed(content)
if collection:
# Store the embedding
if db is None:
db = get_db()
embeddings_migrations.apply(db)
if not existing_collection:
db["collections"].insert(
{
"name": collection,
"model": model.model_id,
}
)
existing_collection = get_collection(db, collection)
# Now store it
db["embeddings"].insert(
{
"collection_id": existing_collection["id"],
"id": id,
"content": content if store else None,
"embedding": encode(embedding),
},
replace=True,
)
if collection_obj:
embedding = collection_obj.embed(id, content, store=store)
else:
embedding = model_obj.embed(content)
if show_output:
if format_ == "json" or format_ is None:
@ -1042,19 +994,15 @@ def similar(collection, id, input, content, number, database):
if not db["embeddings"].exists():
raise click.ClickException("No embeddings table found in database")
collection_exists = False
try:
collection_obj = Collection(db, collection)
collection_exists = collection_obj.exists()
except ValueError:
collection_exists = False
if not collection_exists:
collection_obj = Collection(db, collection, create=False)
except Collection.DoesNotExist:
raise click.ClickException("Collection does not exist")
if id:
try:
results = collection_obj.similar_by_id(id, number)
except ValueError:
except Collection.DoesNotExist:
raise click.ClickException("ID not found in collection")
else:
if not content:
@ -1157,14 +1105,6 @@ def embed_db_collections(database, json_):
)
def get_collection(db, collection):
rows = db["collections"].rows_where("name = ?", [collection])
try:
return next(rows)
except StopIteration:
raise NotFoundError("Collection not found: {}".format(collection))
def template_dir():
path = user_dir() / "templates"
path.mkdir(parents=True, exist_ok=True)

View file

@ -19,6 +19,9 @@ class Entry:
class Collection:
max_batch_size: int = 100
class DoesNotExist(Exception):
pass
def __init__(
self,
db: Database,
@ -26,71 +29,65 @@ class Collection:
*,
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
create: bool = True,
) -> None:
self.db = db
self.name = name
if model and model_id and model.model_id != model_id:
raise ValueError("model_id does not match model.model_id")
self._model = model
self._model_id = model_id
self._id = None
self._id = self.id()
"""
A collection of embeddings
def model(self) -> EmbeddingModel:
Returns the collection with the given name, creating it if it does not exist.
If you set create=False a Collection.DoesNotExist exception will be raised if the
collection does not already exist.
Args:
db (sqlite_utils.Database): Database to store the collection in
name (str): Name of the collection
model (llm.models.EmbeddingModel, optional): Embedding model to use
model_id (str, optional): Alternatively, ID of the embedding model to use
create (bool, optional): Whether to create the collection if it does not exist
"""
import llm
if self._model:
return self._model
try:
if not self._model_id:
raise ValueError("No model_id specified")
self._model = llm.get_embedding_model(self._model_id)
except llm.UnknownModelError:
raise ValueError("No model_id specified and no model found with that name")
return cast(EmbeddingModel, self._model)
self.db = db
self.name = name
self._model = model
def id(self) -> int:
"""
Get the ID of the collection, creating it in the DB if necessary.
embeddings_migrations.apply(self.db)
Returns:
int: ID of the collection
"""
if self._id is not None:
return self._id
if not self.db["collections"].exists():
embeddings_migrations.apply(self.db)
rows = self.db["collections"].rows_where("name = ?", [self.name])
try:
row = next(rows)
self._id = row["id"]
if self._model_id is None:
self._model_id = row["model"]
except StopIteration:
# Create it
self._id = (
cast(Table, self.db["collections"])
.insert(
{
"name": self.name,
"model": self.model().model_id,
}
rows = list(self.db["collections"].rows_where("name = ?", [self.name]))
if rows:
row = rows[0]
self.id = row["id"]
self.model_id = row["model"]
else:
if create:
# Create it
if model_id:
# Resolve alias
model = llm.get_embedding_model(model_id)
self._model = model
model_id = cast(EmbeddingModel, model).model_id
self.id = (
cast(Table, self.db["collections"])
.insert(
{
"name": self.name,
"model": model_id,
}
)
.last_pk
)
.last_pk
)
return cast(int, self._id)
else:
raise self.DoesNotExist(f"Collection '{name}' does not exist")
def exists(self) -> bool:
"""
Check if the collection exists in the DB.
def model(self) -> EmbeddingModel:
"Return the embedding model used by this collection"
import llm
Returns:
bool: True if exists, False otherwise
"""
matches = list(
self.db.query("select 1 from collections where name = ?", (self.name,))
)
return bool(matches)
if self._model is None:
self._model = llm.get_embedding_model(self.model_id)
return cast(EmbeddingModel, self._model)
def count(self) -> int:
"""
@ -118,7 +115,7 @@ class Collection:
store: bool = False,
) -> None:
"""
Embed a text and store it in the collection with a given ID.
Embed text and store it in the collection with a given ID.
Args:
id (str): ID for the text
@ -131,12 +128,13 @@ class Collection:
embedding = self.model().embed(text)
cast(Table, self.db["embeddings"]).insert(
{
"collection_id": self.id(),
"collection_id": self.id,
"id": id,
"embedding": encode(embedding),
"content": text if store else None,
"metadata": json.dumps(metadata) if metadata else None,
}
},
replace=True,
)
def embed_multi(
@ -171,7 +169,7 @@ class Collection:
self.max_batch_size, (self.model().batch_size or self.max_batch_size)
)
iterator = iter(entries)
collection_id = self.id()
collection_id = self.id
while True:
batch = list(islice(iterator, batch_size))
if not batch:
@ -188,7 +186,8 @@ class Collection:
"metadata": json.dumps(metadata) if metadata else None,
}
for (embedding, (id, text, metadata)) in zip(embeddings, batch)
)
),
replace=True,
)
def similar_by_vector(
@ -213,7 +212,7 @@ class Collection:
self.db.register_function(distance_score, replace=True)
where_bits = ["collection_id = ?"]
where_args = [str(self.id())]
where_args = [str(self.id)]
if skip_id:
where_bits.append("id != ?")
@ -255,11 +254,11 @@ class Collection:
matches = list(
self.db["embeddings"].rows_where(
"collection_id = ? and id = ?", (self.id(), id)
"collection_id = ? and id = ?", (self.id, id)
)
)
if not matches:
raise ValueError("ID not found")
raise self.DoesNotExist("ID not found")
embedding = matches[0]["embedding"]
comparison_vector = llm.decode(embedding)
return self.similar_by_vector(comparison_vector, number, skip_id=id)

View file

@ -55,7 +55,7 @@ def test_embed_metadata(collection):
def test_collection(collection):
assert collection.id() == 1
assert collection.id == 1
assert collection.count() == 2
# Check that the embeddings are there
rows = list(collection.db["embeddings"].rows)