mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-01 10:24:46 +00:00
Collection design tweaks + llm embed/similar now use it, closes #204
This commit is contained in:
parent
2f178a9e9b
commit
3d56d6cc24
4 changed files with 98 additions and 142 deletions
|
|
@ -84,6 +84,23 @@ collection.embed_multi_with_metadata(
|
|||
)
|
||||
```
|
||||
|
||||
(embeddings-python-collection-class)=
|
||||
### Collection class reference
|
||||
|
||||
A collection instance has the following properties and methods:
|
||||
|
||||
- `id` - the integer ID of the collection in the database
|
||||
- `name` - the string name of the collection (unique in the database)
|
||||
- `model_id` - the string ID of the embedding model used for this collection
|
||||
- `model()` - returns the `EmbeddingModel` instance, based on that `model_id`
|
||||
- `count()` - returns the integer number of items in the collection
|
||||
- `embed(id: str, text: str, metadata: dict=None, store: bool=False)` - embeds the given string and stores it in the collection under the given ID. Can optionally include metadata (stored as JSON) and store the text content itself in the database table.
|
||||
- `embed_multi(entries: Iterable, store: bool=False)` - see above
|
||||
- `embed_multi_with_metadata(entries: Iterable, store: bool=False)` - see above
|
||||
- `similar(query: str, number: int=10)` - returns a list of entries that are most similar to the embedding of the given query string
|
||||
- `similar_by_id(id: str, number: int=10)` - returns a list of entries that are most similar to the embedding of the item with the given ID
|
||||
- `similar_by_vector(vector: List[float], number: int=10, skip_id: str=None)` - returns a list of entries that are most similar to the given embedding vector, optionally skipping the entry with the given ID
|
||||
|
||||
(embeddings-python-similar)=
|
||||
## Retrieving similar items
|
||||
|
||||
|
|
|
|||
94
llm/cli.py
94
llm/cli.py
|
|
@ -22,7 +22,6 @@ from llm import (
|
|||
)
|
||||
|
||||
from .migrations import migrate
|
||||
from .embeddings_migrations import embeddings_migrations
|
||||
from .plugins import pm
|
||||
import base64
|
||||
import pathlib
|
||||
|
|
@ -30,7 +29,6 @@ import pydantic
|
|||
from runpy import run_module
|
||||
import shutil
|
||||
import sqlite_utils
|
||||
from sqlite_utils.db import NotFoundError
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Optional
|
||||
|
|
@ -901,8 +899,6 @@ def embed(collection, id, input, model, store, database, content, format_):
|
|||
if store and not collection:
|
||||
raise click.ClickException("Must provide collection when using --store")
|
||||
|
||||
db = None
|
||||
|
||||
# Lazy load this because we do not need it for -c or -i versions
|
||||
def get_db():
|
||||
if database:
|
||||
|
|
@ -910,40 +906,20 @@ def embed(collection, id, input, model, store, database, content, format_):
|
|||
else:
|
||||
return sqlite_utils.Database(user_dir() / "embeddings.db")
|
||||
|
||||
existing_collection = None
|
||||
collection_obj = None
|
||||
model_obj = None
|
||||
if collection:
|
||||
db = get_db()
|
||||
if db["collections"].exists():
|
||||
try:
|
||||
existing_collection = get_collection(db, collection)
|
||||
except NotFoundError:
|
||||
pass
|
||||
collection_obj = Collection(db, collection, model_id=model)
|
||||
model_obj = collection_obj.model()
|
||||
|
||||
if model is None:
|
||||
# If collection exists, use that model
|
||||
if existing_collection:
|
||||
model = existing_collection["model"]
|
||||
else:
|
||||
# Use default model
|
||||
if model_obj is None:
|
||||
if not model:
|
||||
model = get_default_embedding_model()
|
||||
|
||||
if model and existing_collection:
|
||||
# Resolve aliases before comparison
|
||||
model_resolved = get_embedding_model(model).model_id
|
||||
collection_model_resolved = get_embedding_model(
|
||||
existing_collection["model"]
|
||||
).model_id
|
||||
if model_resolved != collection_model_resolved:
|
||||
raise click.ClickException(
|
||||
"Model '{}' does not match '{}' collection model of '{}'".format(
|
||||
model, collection, existing_collection["model"]
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
model = get_embedding_model(model)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
try:
|
||||
model_obj = get_embedding_model(model)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
|
||||
show_output = True
|
||||
if collection and (format_ is None):
|
||||
|
|
@ -958,34 +934,10 @@ def embed(collection, id, input, model, store, database, content, format_):
|
|||
if not content:
|
||||
raise click.ClickException("No content provided")
|
||||
|
||||
embedding = model.embed(content)
|
||||
|
||||
if collection:
|
||||
# Store the embedding
|
||||
if db is None:
|
||||
db = get_db()
|
||||
|
||||
embeddings_migrations.apply(db)
|
||||
|
||||
if not existing_collection:
|
||||
db["collections"].insert(
|
||||
{
|
||||
"name": collection,
|
||||
"model": model.model_id,
|
||||
}
|
||||
)
|
||||
existing_collection = get_collection(db, collection)
|
||||
|
||||
# Now store it
|
||||
db["embeddings"].insert(
|
||||
{
|
||||
"collection_id": existing_collection["id"],
|
||||
"id": id,
|
||||
"content": content if store else None,
|
||||
"embedding": encode(embedding),
|
||||
},
|
||||
replace=True,
|
||||
)
|
||||
if collection_obj:
|
||||
embedding = collection_obj.embed(id, content, store=store)
|
||||
else:
|
||||
embedding = model_obj.embed(content)
|
||||
|
||||
if show_output:
|
||||
if format_ == "json" or format_ is None:
|
||||
|
|
@ -1042,19 +994,15 @@ def similar(collection, id, input, content, number, database):
|
|||
if not db["embeddings"].exists():
|
||||
raise click.ClickException("No embeddings table found in database")
|
||||
|
||||
collection_exists = False
|
||||
try:
|
||||
collection_obj = Collection(db, collection)
|
||||
collection_exists = collection_obj.exists()
|
||||
except ValueError:
|
||||
collection_exists = False
|
||||
if not collection_exists:
|
||||
collection_obj = Collection(db, collection, create=False)
|
||||
except Collection.DoesNotExist:
|
||||
raise click.ClickException("Collection does not exist")
|
||||
|
||||
if id:
|
||||
try:
|
||||
results = collection_obj.similar_by_id(id, number)
|
||||
except ValueError:
|
||||
except Collection.DoesNotExist:
|
||||
raise click.ClickException("ID not found in collection")
|
||||
else:
|
||||
if not content:
|
||||
|
|
@ -1157,14 +1105,6 @@ def embed_db_collections(database, json_):
|
|||
)
|
||||
|
||||
|
||||
def get_collection(db, collection):
|
||||
rows = db["collections"].rows_where("name = ?", [collection])
|
||||
try:
|
||||
return next(rows)
|
||||
except StopIteration:
|
||||
raise NotFoundError("Collection not found: {}".format(collection))
|
||||
|
||||
|
||||
def template_dir():
|
||||
path = user_dir() / "templates"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ class Entry:
|
|||
class Collection:
|
||||
max_batch_size: int = 100
|
||||
|
||||
class DoesNotExist(Exception):
|
||||
pass
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Database,
|
||||
|
|
@ -26,71 +29,65 @@ class Collection:
|
|||
*,
|
||||
model: Optional[EmbeddingModel] = None,
|
||||
model_id: Optional[str] = None,
|
||||
create: bool = True,
|
||||
) -> None:
|
||||
self.db = db
|
||||
self.name = name
|
||||
if model and model_id and model.model_id != model_id:
|
||||
raise ValueError("model_id does not match model.model_id")
|
||||
self._model = model
|
||||
self._model_id = model_id
|
||||
self._id = None
|
||||
self._id = self.id()
|
||||
"""
|
||||
A collection of embeddings
|
||||
|
||||
def model(self) -> EmbeddingModel:
|
||||
Returns the collection with the given name, creating it if it does not exist.
|
||||
|
||||
If you set create=False a Collection.DoesNotExist exception will be raised if the
|
||||
collection does not already exist.
|
||||
|
||||
Args:
|
||||
db (sqlite_utils.Database): Database to store the collection in
|
||||
name (str): Name of the collection
|
||||
model (llm.models.EmbeddingModel, optional): Embedding model to use
|
||||
model_id (str, optional): Alternatively, ID of the embedding model to use
|
||||
create (bool, optional): Whether to create the collection if it does not exist
|
||||
"""
|
||||
import llm
|
||||
|
||||
if self._model:
|
||||
return self._model
|
||||
try:
|
||||
if not self._model_id:
|
||||
raise ValueError("No model_id specified")
|
||||
self._model = llm.get_embedding_model(self._model_id)
|
||||
except llm.UnknownModelError:
|
||||
raise ValueError("No model_id specified and no model found with that name")
|
||||
return cast(EmbeddingModel, self._model)
|
||||
self.db = db
|
||||
self.name = name
|
||||
self._model = model
|
||||
|
||||
def id(self) -> int:
|
||||
"""
|
||||
Get the ID of the collection, creating it in the DB if necessary.
|
||||
embeddings_migrations.apply(self.db)
|
||||
|
||||
Returns:
|
||||
int: ID of the collection
|
||||
"""
|
||||
if self._id is not None:
|
||||
return self._id
|
||||
if not self.db["collections"].exists():
|
||||
embeddings_migrations.apply(self.db)
|
||||
rows = self.db["collections"].rows_where("name = ?", [self.name])
|
||||
try:
|
||||
row = next(rows)
|
||||
self._id = row["id"]
|
||||
if self._model_id is None:
|
||||
self._model_id = row["model"]
|
||||
except StopIteration:
|
||||
# Create it
|
||||
self._id = (
|
||||
cast(Table, self.db["collections"])
|
||||
.insert(
|
||||
{
|
||||
"name": self.name,
|
||||
"model": self.model().model_id,
|
||||
}
|
||||
rows = list(self.db["collections"].rows_where("name = ?", [self.name]))
|
||||
if rows:
|
||||
row = rows[0]
|
||||
self.id = row["id"]
|
||||
self.model_id = row["model"]
|
||||
else:
|
||||
if create:
|
||||
# Create it
|
||||
if model_id:
|
||||
# Resolve alias
|
||||
model = llm.get_embedding_model(model_id)
|
||||
self._model = model
|
||||
model_id = cast(EmbeddingModel, model).model_id
|
||||
self.id = (
|
||||
cast(Table, self.db["collections"])
|
||||
.insert(
|
||||
{
|
||||
"name": self.name,
|
||||
"model": model_id,
|
||||
}
|
||||
)
|
||||
.last_pk
|
||||
)
|
||||
.last_pk
|
||||
)
|
||||
return cast(int, self._id)
|
||||
else:
|
||||
raise self.DoesNotExist(f"Collection '{name}' does not exist")
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""
|
||||
Check if the collection exists in the DB.
|
||||
def model(self) -> EmbeddingModel:
|
||||
"Return the embedding model used by this collection"
|
||||
import llm
|
||||
|
||||
Returns:
|
||||
bool: True if exists, False otherwise
|
||||
"""
|
||||
matches = list(
|
||||
self.db.query("select 1 from collections where name = ?", (self.name,))
|
||||
)
|
||||
return bool(matches)
|
||||
if self._model is None:
|
||||
self._model = llm.get_embedding_model(self.model_id)
|
||||
|
||||
return cast(EmbeddingModel, self._model)
|
||||
|
||||
def count(self) -> int:
|
||||
"""
|
||||
|
|
@ -118,7 +115,7 @@ class Collection:
|
|||
store: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Embed a text and store it in the collection with a given ID.
|
||||
Embed text and store it in the collection with a given ID.
|
||||
|
||||
Args:
|
||||
id (str): ID for the text
|
||||
|
|
@ -131,12 +128,13 @@ class Collection:
|
|||
embedding = self.model().embed(text)
|
||||
cast(Table, self.db["embeddings"]).insert(
|
||||
{
|
||||
"collection_id": self.id(),
|
||||
"collection_id": self.id,
|
||||
"id": id,
|
||||
"embedding": encode(embedding),
|
||||
"content": text if store else None,
|
||||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
}
|
||||
},
|
||||
replace=True,
|
||||
)
|
||||
|
||||
def embed_multi(
|
||||
|
|
@ -171,7 +169,7 @@ class Collection:
|
|||
self.max_batch_size, (self.model().batch_size or self.max_batch_size)
|
||||
)
|
||||
iterator = iter(entries)
|
||||
collection_id = self.id()
|
||||
collection_id = self.id
|
||||
while True:
|
||||
batch = list(islice(iterator, batch_size))
|
||||
if not batch:
|
||||
|
|
@ -188,7 +186,8 @@ class Collection:
|
|||
"metadata": json.dumps(metadata) if metadata else None,
|
||||
}
|
||||
for (embedding, (id, text, metadata)) in zip(embeddings, batch)
|
||||
)
|
||||
),
|
||||
replace=True,
|
||||
)
|
||||
|
||||
def similar_by_vector(
|
||||
|
|
@ -213,7 +212,7 @@ class Collection:
|
|||
self.db.register_function(distance_score, replace=True)
|
||||
|
||||
where_bits = ["collection_id = ?"]
|
||||
where_args = [str(self.id())]
|
||||
where_args = [str(self.id)]
|
||||
|
||||
if skip_id:
|
||||
where_bits.append("id != ?")
|
||||
|
|
@ -255,11 +254,11 @@ class Collection:
|
|||
|
||||
matches = list(
|
||||
self.db["embeddings"].rows_where(
|
||||
"collection_id = ? and id = ?", (self.id(), id)
|
||||
"collection_id = ? and id = ?", (self.id, id)
|
||||
)
|
||||
)
|
||||
if not matches:
|
||||
raise ValueError("ID not found")
|
||||
raise self.DoesNotExist("ID not found")
|
||||
embedding = matches[0]["embedding"]
|
||||
comparison_vector = llm.decode(embedding)
|
||||
return self.similar_by_vector(comparison_vector, number, skip_id=id)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def test_embed_metadata(collection):
|
|||
|
||||
|
||||
def test_collection(collection):
|
||||
assert collection.id() == 1
|
||||
assert collection.id == 1
|
||||
assert collection.count() == 2
|
||||
# Check that the embeddings are there
|
||||
rows = list(collection.db["embeddings"].rows)
|
||||
|
|
|
|||
Loading…
Reference in a new issue