Initial Collection class plus test, refs #191

This commit is contained in:
Simon Willison 2023-09-01 13:04:05 -07:00
parent c25e7c4713
commit 6f761702dc
4 changed files with 205 additions and 9 deletions

View file

@ -13,6 +13,7 @@ from .models import (
Prompt,
Response,
)
from .embeddings import Collection
from .templates import Template
from .plugins import pm
import click
@ -20,12 +21,14 @@ from typing import Dict, List, Optional
import json
import os
import pathlib
import struct
__all__ = [
"hookimpl",
"get_model",
"get_key",
"user_dir",
"Collection",
"Conversation",
"Model",
"Options",
@ -226,3 +229,11 @@ def remove_alias(alias):
raise KeyError("No such alias: {}".format(alias))
del current[alias]
path.write_text(json.dumps(current, indent=4) + "\n")
def encode(values):
return struct.pack("<" + "f" * len(values), *values)
def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)

View file

@ -6,6 +6,8 @@ from llm import (
Response,
Template,
UnknownModelError,
decode,
encode,
get_embedding_models_with_aliases,
get_embedding_model,
get_key,
@ -28,7 +30,6 @@ from runpy import run_module
import shutil
import sqlite_utils
from sqlite_utils.db import NotFoundError
import struct
import sys
import textwrap
from typing import cast, Optional
@ -1288,14 +1289,6 @@ def logs_on():
return not (user_dir() / "logs-off").exists()
def encode(values):
return struct.pack("<" + "f" * len(values), *values)
def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)
def cosine_similarity(a, b):
dot_product = sum(x * y for x, y in zip(a, b))
magnitude_a = sum(x * x for x in a) ** 0.5

162
llm/embeddings.py Normal file
View file

@ -0,0 +1,162 @@
from .models import EmbeddingModel
from .embeddings_migrations import embeddings_migrations
import json
from sqlite_utils import Database
from typing import Any, Dict, List, Tuple, Optional, Union
class Collection:
def __init__(
self,
db: Database,
name: str,
*,
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
) -> None:
from llm import get_embedding_model
self.db = db
self.name = name
if model and model_id and model.model_id != model_id:
raise ValueError("model_id does not match model.model_id")
if model_id and not model:
model = get_embedding_model(model_id)
self.model = model
self._id = None
def id(self) -> int:
"""
Get the ID of the collection, creating it in the DB if necessary.
Returns:
int: ID of the collection
"""
if self._id is not None:
return self._id
if not self.db["collections"].exists():
embeddings_migrations.apply(self.db)
rows = self.db["collections"].rows_where("name = ?", [self.name])
try:
row = next(rows)
self._id = row["id"]
except StopIteration:
# Create it
self._id = (
self.db["collections"]
.insert(
{
"name": self.name,
"model": self.model.model_id,
}
)
.last_pk
)
return self._id
def exists(self) -> bool:
"""
Check if the collection exists in the DB.
Returns:
bool: True if exists, False otherwise
"""
matches = list(
self.db.query("select 1 from collections where name = ?", (self.name,))
)
return bool(matches)
def count(self) -> int:
"""
Count the number of items in the collection.
Returns:
int: Number of items in the collection
"""
return next(
self.db.query(
"""
select count(*) as c from embeddings where collection_id = (
select id from collections where name = ?
)
""",
(self.name,),
)
)["c"]
def embed(
self,
id: str,
text: str,
metadata: Optional[Dict[str, Any]] = None,
store: bool = False,
) -> None:
"""
Embed a text and store it in the collection with a given ID.
Args:
id (str): ID for the text
text (str): Text to be embedded
metadata (dict, optional): Metadata to be stored
store (bool, optional): Whether to store the text in the content column
"""
from llm import encode
embedding = self.model.embed(text)
self.db["embeddings"].insert(
{
"collection_id": self.id(),
"id": id,
"embedding": encode(embedding),
"content": text if store else None,
"metadata": json.dumps(metadata) if metadata else None,
}
)
def embed_multi(self, id_text_map: Dict[str, str], store: bool = False) -> None:
"""
Embed multiple texts and store them in the collection with given IDs.
Args:
id_text_map (dict): Dictionary mapping IDs to texts
store (bool, optional): Whether to store the text in the content column
"""
raise NotImplementedError
def embed_multi_with_metadata(
self,
id_text_metadata_map: Dict[str, Tuple[str, Dict[str, Union[str, int, float]]]],
) -> None:
"""
Embed multiple texts along with metadata and store them in the collection with given IDs.
Args:
id_text_metadata_map (dict): Dictionary mapping IDs to (text, metadata) tuples
"""
raise NotImplementedError
def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
"""
Find similar items in the collection by a given ID.
Args:
id (str): ID to search by
number (int, optional): Number of similar items to return
Returns:
list: List of (id, score) tuples
"""
raise NotImplementedError
def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
"""
Find similar items in the collection by a given text.
Args:
text (str): Text to search by
number (int, optional): Number of similar items to return
Returns:
list: List of (id, score) tuples
"""
raise NotImplementedError

View file

@ -1,4 +1,5 @@
import llm
import sqlite_utils
def test_demo_plugin():
@ -18,3 +19,32 @@ def test_embed_huge_list():
assert first_twos == {(5, 1): 10, (5, 2): 90, (5, 3): 900}
# Should have happened in 100 batches
assert model.batch_count == 100
def test_collection():
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
assert collection.id() == 1
assert collection.count() == 0
# Embed some stuff
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
assert collection.count() == 2
# Check that the embeddings are there
rows = list(db["embeddings"].rows)
assert rows == [
{
"collection_id": 1,
"id": "1",
"embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"metadata": None,
},
{
"collection_id": 1,
"id": "2",
"embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
"content": None,
"metadata": None,
},
]