mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-12 09:11:00 +00:00
313 lines
10 KiB
Python
313 lines
10 KiB
Python
from .models import EmbeddingModel
|
|
from .embeddings_migrations import embeddings_migrations
|
|
from dataclasses import dataclass
|
|
import hashlib
|
|
from itertools import islice
|
|
import json
|
|
from sqlite_utils import Database
|
|
from sqlite_utils.db import Table
|
|
import time
|
|
from typing import cast, Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
|
|
@dataclass
class Entry:
    """A single item returned from a similarity search against a collection."""

    # ID of the stored item
    id: str
    # Similarity score for the item, if one was computed
    score: Optional[float]
    # Stored text content; None when the content column is empty
    content: Optional[str] = None
    # Decoded JSON metadata; None when no metadata was stored
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class Collection:
|
|
max_batch_size: int = 100
|
|
|
|
class DoesNotExist(Exception):
|
|
pass
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
db: Optional[Database] = None,
|
|
*,
|
|
model: Optional[EmbeddingModel] = None,
|
|
model_id: Optional[str] = None,
|
|
create: bool = True,
|
|
) -> None:
|
|
"""
|
|
A collection of embeddings
|
|
|
|
Returns the collection with the given name, creating it if it does not exist.
|
|
|
|
If you set create=False a Collection.DoesNotExist exception will be raised if the
|
|
collection does not already exist.
|
|
|
|
Args:
|
|
db (sqlite_utils.Database): Database to store the collection in
|
|
name (str): Name of the collection
|
|
model (llm.models.EmbeddingModel, optional): Embedding model to use
|
|
model_id (str, optional): Alternatively, ID of the embedding model to use
|
|
create (bool, optional): Whether to create the collection if it does not exist
|
|
"""
|
|
import llm
|
|
|
|
self.db = db or Database(memory=True)
|
|
self.name = name
|
|
self._model = model
|
|
|
|
embeddings_migrations.apply(self.db)
|
|
|
|
rows = list(self.db["collections"].rows_where("name = ?", [self.name]))
|
|
if rows:
|
|
row = rows[0]
|
|
self.id = row["id"]
|
|
self.model_id = row["model"]
|
|
else:
|
|
if create:
|
|
# Collection does not exist, so model or model_id is required
|
|
if not model and not model_id:
|
|
raise ValueError(
|
|
"Either model= or model_id= must be provided when creating a new collection"
|
|
)
|
|
# Create it
|
|
if model_id:
|
|
# Resolve alias
|
|
model = llm.get_embedding_model(model_id)
|
|
self._model = model
|
|
model_id = cast(EmbeddingModel, model).model_id
|
|
self.id = (
|
|
cast(Table, self.db["collections"])
|
|
.insert(
|
|
{
|
|
"name": self.name,
|
|
"model": model_id,
|
|
}
|
|
)
|
|
.last_pk
|
|
)
|
|
else:
|
|
raise self.DoesNotExist(f"Collection '{name}' does not exist")
|
|
|
|
def model(self) -> EmbeddingModel:
|
|
"Return the embedding model used by this collection"
|
|
import llm
|
|
|
|
if self._model is None:
|
|
self._model = llm.get_embedding_model(self.model_id)
|
|
|
|
return cast(EmbeddingModel, self._model)
|
|
|
|
def count(self) -> int:
|
|
"""
|
|
Count the number of items in the collection.
|
|
|
|
Returns:
|
|
int: Number of items in the collection
|
|
"""
|
|
return next(
|
|
self.db.query(
|
|
"""
|
|
select count(*) as c from embeddings where collection_id = (
|
|
select id from collections where name = ?
|
|
)
|
|
""",
|
|
(self.name,),
|
|
)
|
|
)["c"]
|
|
|
|
def embed(
|
|
self,
|
|
id: str,
|
|
text: str,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
store: bool = False,
|
|
) -> None:
|
|
"""
|
|
Embed text and store it in the collection with a given ID.
|
|
|
|
Args:
|
|
id (str): ID for the text
|
|
text (str): Text to be embedded
|
|
metadata (dict, optional): Metadata to be stored
|
|
store (bool, optional): Whether to store the text in the content column
|
|
"""
|
|
from llm import encode
|
|
|
|
embedding = self.model().embed(text)
|
|
cast(Table, self.db["embeddings"]).insert(
|
|
{
|
|
"collection_id": self.id,
|
|
"id": id,
|
|
"embedding": encode(embedding),
|
|
"content": text if store else None,
|
|
"content_hash": self.content_hash(text),
|
|
"metadata": json.dumps(metadata) if metadata else None,
|
|
"updated": int(time.time()),
|
|
},
|
|
replace=True,
|
|
)
|
|
|
|
def embed_multi(
|
|
self, entries: Iterable[Tuple[str, str]], store: bool = False
|
|
) -> None:
|
|
"""
|
|
Embed multiple texts and store them in the collection with given IDs.
|
|
|
|
Args:
|
|
entries (iterable): Iterable of (id: str, text: str) tuples
|
|
store (bool, optional): Whether to store the text in the content column
|
|
"""
|
|
self.embed_multi_with_metadata(
|
|
((id, text, None) for id, text in entries), store=store
|
|
)
|
|
|
|
def embed_multi_with_metadata(
|
|
self,
|
|
entries: Iterable[Tuple[str, str, Optional[Dict[str, Any]]]],
|
|
store: bool = False,
|
|
) -> None:
|
|
"""
|
|
Embed multiple texts along with metadata and store them in the collection with given IDs.
|
|
|
|
Args:
|
|
entries (iterable): Iterable of (id: str, text: str, metadata: None or dict)
|
|
store (bool, optional): Whether to store the text in the content column
|
|
"""
|
|
import llm
|
|
|
|
batch_size = min(
|
|
self.max_batch_size, (self.model().batch_size or self.max_batch_size)
|
|
)
|
|
iterator = iter(entries)
|
|
collection_id = self.id
|
|
while True:
|
|
batch = list(islice(iterator, batch_size))
|
|
if not batch:
|
|
break
|
|
embeddings = list(self.model().embed_multi(item[1] for item in batch))
|
|
with self.db.conn:
|
|
cast(Table, self.db["embeddings"]).insert_all(
|
|
(
|
|
{
|
|
"collection_id": collection_id,
|
|
"id": id,
|
|
"embedding": llm.encode(embedding),
|
|
"content": text if store else None,
|
|
"content_hash": self.content_hash(text),
|
|
"metadata": json.dumps(metadata) if metadata else None,
|
|
"updated": int(time.time()),
|
|
}
|
|
for (embedding, (id, text, metadata)) in zip(embeddings, batch)
|
|
),
|
|
replace=True,
|
|
)
|
|
|
|
def similar_by_vector(
|
|
self, vector: List[float], number: int = 10, skip_id: Optional[str] = None
|
|
) -> List[Entry]:
|
|
"""
|
|
Find similar items in the collection by a given vector.
|
|
|
|
Args:
|
|
vector (list): Vector to search by
|
|
number (int, optional): Number of similar items to return
|
|
|
|
Returns:
|
|
list: List of Entry objects
|
|
"""
|
|
import llm
|
|
|
|
def distance_score(other_encoded):
|
|
other_vector = llm.decode(other_encoded)
|
|
return llm.cosine_similarity(other_vector, vector)
|
|
|
|
self.db.register_function(distance_score, replace=True)
|
|
|
|
where_bits = ["collection_id = ?"]
|
|
where_args = [str(self.id)]
|
|
|
|
if skip_id:
|
|
where_bits.append("id != ?")
|
|
where_args.append(skip_id)
|
|
|
|
return [
|
|
Entry(
|
|
id=row["id"],
|
|
score=row["score"],
|
|
content=row["content"],
|
|
metadata=json.loads(row["metadata"]) if row["metadata"] else None,
|
|
)
|
|
for row in self.db.query(
|
|
"""
|
|
select id, content, metadata, distance_score(embedding) as score
|
|
from embeddings
|
|
where {where}
|
|
order by score desc limit {number}
|
|
""".format(
|
|
where=" and ".join(where_bits),
|
|
number=number,
|
|
),
|
|
where_args,
|
|
)
|
|
]
|
|
|
|
def similar_by_id(self, id: str, number: int = 10) -> List[Entry]:
|
|
"""
|
|
Find similar items in the collection by a given ID.
|
|
|
|
Args:
|
|
id (str): ID to search by
|
|
number (int, optional): Number of similar items to return
|
|
|
|
Returns:
|
|
list: List of Entry objects
|
|
"""
|
|
import llm
|
|
|
|
matches = list(
|
|
self.db["embeddings"].rows_where(
|
|
"collection_id = ? and id = ?", (self.id, id)
|
|
)
|
|
)
|
|
if not matches:
|
|
raise self.DoesNotExist("ID not found")
|
|
embedding = matches[0]["embedding"]
|
|
comparison_vector = llm.decode(embedding)
|
|
return self.similar_by_vector(comparison_vector, number, skip_id=id)
|
|
|
|
def similar(self, text: str, number: int = 10) -> List[Entry]:
|
|
"""
|
|
Find similar items in the collection by a given text.
|
|
|
|
Args:
|
|
text (str): Text to search by
|
|
number (int, optional): Number of similar items to return
|
|
|
|
Returns:
|
|
list: List of Entry objects
|
|
"""
|
|
comparison_vector = self.model().embed(text)
|
|
return self.similar_by_vector(comparison_vector, number)
|
|
|
|
@classmethod
|
|
def exists(cls, db: Database, name: str) -> bool:
|
|
"""
|
|
Does this collection exist in the database?
|
|
|
|
Args:
|
|
name (str): Name of the collection
|
|
"""
|
|
rows = list(db["collections"].rows_where("name = ?", [name]))
|
|
return bool(rows)
|
|
|
|
def delete(self):
|
|
"""
|
|
Delete the collection and its embeddings from the database
|
|
"""
|
|
with self.db.conn:
|
|
self.db.execute("delete from embeddings where collection_id = ?", [self.id])
|
|
self.db.execute("delete from collections where id = ?", [self.id])
|
|
|
|
@staticmethod
|
|
def content_hash(text: str) -> bytes:
|
|
"Hash content for deduplication. Override to change hashing behavior."
|
|
return hashlib.md5(text.encode("utf8")).digest()
|