Collection.similar methods, refs #191

This commit is contained in:
Simon Willison 2023-09-01 16:26:58 -07:00
parent 7a4429f100
commit 0ec516559a
4 changed files with 116 additions and 72 deletions

View file

@ -237,3 +237,10 @@ def encode(values):
def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)
def cosine_similarity(a, b):
dot_product = sum(x * y for x, y in zip(a, b))
magnitude_a = sum(x * x for x in a) ** 0.5
magnitude_b = sum(x * x for x in b) ** 0.5
return dot_product / (magnitude_a * magnitude_b)

View file

@ -2,11 +2,11 @@ import click
from click_default_group import DefaultGroup
import json
from llm import (
Collection,
Conversation,
Response,
Template,
UnknownModelError,
decode,
encode,
get_embedding_models_with_aliases,
get_embedding_model,
@ -1028,24 +1028,13 @@ def similar(collection, id, input, content, number, database):
if not db["embeddings"].exists():
raise click.ClickException("No embeddings table found in database")
try:
collection_row = get_collection(db, collection)
except NotFoundError:
collection_obj = Collection(db, collection)
if not collection_obj.exists():
raise click.ClickException("Collection does not exist")
# If id was provided, we compare against that
if id:
matches = list(
db["embeddings"].rows_where(
"collection_id = ? and id = ?", (collection_row["id"], id)
)
)
if not matches:
raise click.ClickException("No match found for id: {}".format(id))
embedding = matches[0]["embedding"]
comparison_vector = decode(embedding)
results = collection_obj.similar_by_id(id, number)
else:
# Embed the content that was provided instead
if not content:
if not input:
# Read from stdin
@ -1053,38 +1042,7 @@ def similar(collection, id, input, content, number, database):
content = input.read()
if not content:
raise click.ClickException("No content provided")
model = collection_row["model"]
try:
model = get_embedding_model(model)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
comparison_vector = model.embed(content)
def distance_score(other_encoded):
other_vector = decode(other_encoded)
return cosine_similarity(other_vector, comparison_vector)
db.register_function(distance_score)
where_bits = ["collection_id = ?"]
where_args = [collection_row["id"]]
if id:
where_bits.append("id != ?")
where_args.append(id)
results = db.query(
"""
select id, distance_score(embedding) as score
from embeddings
where {where}
order by score desc limit {number}
""".format(
where=" and ".join(where_bits),
number=number,
),
where_args,
)
results = collection_obj.similar_by_content(content, number)
for result in results:
click.echo(json.dumps(result))
@ -1287,10 +1245,3 @@ def _human_readable_size(size_bytes):
def logs_on():
return not (user_dir() / "logs-off").exists()
def cosine_similarity(a, b):
dot_product = sum(x * y for x, y in zip(a, b))
magnitude_a = sum(x * x for x in a) ** 0.5
magnitude_b = sum(x * x for x in b) ** 0.5
return dot_product / (magnitude_a * magnitude_b)

View file

@ -15,16 +15,25 @@ class Collection:
model: Optional[EmbeddingModel] = None,
model_id: Optional[str] = None,
) -> None:
from llm import get_embedding_model
self.db = db
self.name = name
if model and model_id and model.model_id != model_id:
raise ValueError("model_id does not match model.model_id")
if model_id and not model:
model = get_embedding_model(model_id)
self.model = model
self._id: Optional[int] = None
self._model = model
self._model_id = model_id
self._id = None
self._id = self.id()
def model(self) -> EmbeddingModel:
import llm
if self._model:
return self._model
try:
self._model = llm.get_embedding_model(self._model_id)
except llm.UnknownModelError:
raise ValueError("No model_id specified and no model found with that name")
return cast(EmbeddingModel, self._model)
def id(self) -> int:
"""
@ -41,6 +50,8 @@ class Collection:
try:
row = next(rows)
self._id = row["id"]
if self._model_id is None:
self._model_id = row["model"]
except StopIteration:
# Create it
self._id = (
@ -48,7 +59,7 @@ class Collection:
.insert(
{
"name": self.name,
"model": cast(EmbeddingModel, self.model).model_id,
"model": self.model().model_id,
}
)
.last_pk
@ -103,7 +114,7 @@ class Collection:
"""
from llm import encode
embedding = cast(EmbeddingModel, self.model).embed(text)
embedding = self.model().embed(text)
cast(Table, self.db["embeddings"]).insert(
{
"collection_id": self.id(),
@ -136,6 +147,50 @@ class Collection:
"""
raise NotImplementedError
def similar_by_vector(
self, vector: List[float], number: int = 5, skip_id: Optional[str] = None
) -> List[Tuple[str, float]]:
"""
Find similar items in the collection by a given vector.
Args:
vector (list): Vector to search by
number (int, optional): Number of similar items to return
Returns:
list: List of (id, score) tuples
"""
import llm
def distance_score(other_encoded):
other_vector = llm.decode(other_encoded)
return llm.cosine_similarity(other_vector, vector)
self.db.register_function(distance_score, replace=True)
where_bits = ["collection_id = ?"]
where_args = [str(self.id())]
if skip_id:
where_bits.append("id != ?")
where_args.append(skip_id)
return [
(row["id"], row["score"])
for row in self.db.query(
"""
select id, distance_score(embedding) as score
from embeddings
where {where}
order by score desc limit {number}
""".format(
where=" and ".join(where_bits),
number=number,
),
where_args,
)
]
def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
"""
Find similar items in the collection by a given ID.
@ -147,7 +202,18 @@ class Collection:
Returns:
list: List of (id, score) tuples
"""
raise NotImplementedError
import llm
matches = list(
self.db["embeddings"].rows_where(
"collection_id = ? and id = ?", (self.id(), id)
)
)
if not matches:
raise ValueError("ID not found")
embedding = matches[0]["embedding"]
comparison_vector = llm.decode(embedding)
return self.similar_by_vector(comparison_vector, number, skip_id=id)
def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
"""
@ -160,4 +226,5 @@ class Collection:
Returns:
list: List of (id, score) tuples
"""
raise NotImplementedError
comparison_vector = self.model().embed(text)
return self.similar_by_vector(comparison_vector, number)

View file

@ -1,5 +1,15 @@
import llm
import sqlite_utils
import pytest
@pytest.fixture
def collection():
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
return collection
def test_demo_plugin():
@ -21,17 +31,11 @@ def test_embed_huge_list():
assert model.batch_count == 100
def test_collection():
db = sqlite_utils.Database(memory=True)
collection = llm.Collection(db, "test", model_id="embed-demo")
def test_collection(collection):
assert collection.id() == 1
assert collection.count() == 0
# Embed some stuff
collection.embed(1, "hello world")
collection.embed(2, "goodbye world")
assert collection.count() == 2
# Check that the embeddings are there
rows = list(db["embeddings"].rows)
rows = list(collection.db["embeddings"].rows)
assert rows == [
{
"collection_id": 1,
@ -48,3 +52,18 @@ def test_collection():
"metadata": None,
},
]
def test_similar(collection):
results = list(collection.similar("hello world"))
assert results == [
{"id": "1", "score": pytest.approx(0.9999999999999999)},
{"id": "2", "score": pytest.approx(0.9863939238321437)},
]
def test_similar_by_id(collection):
results = list(collection.similar_by_id("1"))
assert results == [
{"id": "2", "score": pytest.approx(0.9863939238321437)},
]