mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-24 21:33:43 +00:00
Collection.similar methods, refs #191
This commit is contained in:
parent
7a4429f100
commit
0ec516559a
4 changed files with 116 additions and 72 deletions
|
|
@ -237,3 +237,10 @@ def encode(values):
|
|||
|
||||
def decode(binary):
|
||||
return struct.unpack("<" + "f" * (len(binary) // 4), binary)
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
dot_product = sum(x * y for x, y in zip(a, b))
|
||||
magnitude_a = sum(x * x for x in a) ** 0.5
|
||||
magnitude_b = sum(x * x for x in b) ** 0.5
|
||||
return dot_product / (magnitude_a * magnitude_b)
|
||||
|
|
|
|||
59
llm/cli.py
59
llm/cli.py
|
|
@ -2,11 +2,11 @@ import click
|
|||
from click_default_group import DefaultGroup
|
||||
import json
|
||||
from llm import (
|
||||
Collection,
|
||||
Conversation,
|
||||
Response,
|
||||
Template,
|
||||
UnknownModelError,
|
||||
decode,
|
||||
encode,
|
||||
get_embedding_models_with_aliases,
|
||||
get_embedding_model,
|
||||
|
|
@ -1028,24 +1028,13 @@ def similar(collection, id, input, content, number, database):
|
|||
if not db["embeddings"].exists():
|
||||
raise click.ClickException("No embeddings table found in database")
|
||||
|
||||
try:
|
||||
collection_row = get_collection(db, collection)
|
||||
except NotFoundError:
|
||||
collection_obj = Collection(db, collection)
|
||||
if not collection_obj.exists():
|
||||
raise click.ClickException("Collection does not exist")
|
||||
|
||||
# If id was provided, we compare against that
|
||||
if id:
|
||||
matches = list(
|
||||
db["embeddings"].rows_where(
|
||||
"collection_id = ? and id = ?", (collection_row["id"], id)
|
||||
)
|
||||
)
|
||||
if not matches:
|
||||
raise click.ClickException("No match found for id: {}".format(id))
|
||||
embedding = matches[0]["embedding"]
|
||||
comparison_vector = decode(embedding)
|
||||
results = collection_obj.similar_by_id(id, number)
|
||||
else:
|
||||
# Embed the content that was provided instead
|
||||
if not content:
|
||||
if not input:
|
||||
# Read from stdin
|
||||
|
|
@ -1053,38 +1042,7 @@ def similar(collection, id, input, content, number, database):
|
|||
content = input.read()
|
||||
if not content:
|
||||
raise click.ClickException("No content provided")
|
||||
model = collection_row["model"]
|
||||
try:
|
||||
model = get_embedding_model(model)
|
||||
except UnknownModelError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
comparison_vector = model.embed(content)
|
||||
|
||||
def distance_score(other_encoded):
|
||||
other_vector = decode(other_encoded)
|
||||
return cosine_similarity(other_vector, comparison_vector)
|
||||
|
||||
db.register_function(distance_score)
|
||||
|
||||
where_bits = ["collection_id = ?"]
|
||||
where_args = [collection_row["id"]]
|
||||
|
||||
if id:
|
||||
where_bits.append("id != ?")
|
||||
where_args.append(id)
|
||||
|
||||
results = db.query(
|
||||
"""
|
||||
select id, distance_score(embedding) as score
|
||||
from embeddings
|
||||
where {where}
|
||||
order by score desc limit {number}
|
||||
""".format(
|
||||
where=" and ".join(where_bits),
|
||||
number=number,
|
||||
),
|
||||
where_args,
|
||||
)
|
||||
results = collection_obj.similar_by_content(content, number)
|
||||
|
||||
for result in results:
|
||||
click.echo(json.dumps(result))
|
||||
|
|
@ -1287,10 +1245,3 @@ def _human_readable_size(size_bytes):
|
|||
|
||||
def logs_on():
|
||||
return not (user_dir() / "logs-off").exists()
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
dot_product = sum(x * y for x, y in zip(a, b))
|
||||
magnitude_a = sum(x * x for x in a) ** 0.5
|
||||
magnitude_b = sum(x * x for x in b) ** 0.5
|
||||
return dot_product / (magnitude_a * magnitude_b)
|
||||
|
|
|
|||
|
|
@ -15,16 +15,25 @@ class Collection:
|
|||
model: Optional[EmbeddingModel] = None,
|
||||
model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
from llm import get_embedding_model
|
||||
|
||||
self.db = db
|
||||
self.name = name
|
||||
if model and model_id and model.model_id != model_id:
|
||||
raise ValueError("model_id does not match model.model_id")
|
||||
if model_id and not model:
|
||||
model = get_embedding_model(model_id)
|
||||
self.model = model
|
||||
self._id: Optional[int] = None
|
||||
self._model = model
|
||||
self._model_id = model_id
|
||||
self._id = None
|
||||
self._id = self.id()
|
||||
|
||||
def model(self) -> EmbeddingModel:
|
||||
import llm
|
||||
|
||||
if self._model:
|
||||
return self._model
|
||||
try:
|
||||
self._model = llm.get_embedding_model(self._model_id)
|
||||
except llm.UnknownModelError:
|
||||
raise ValueError("No model_id specified and no model found with that name")
|
||||
return cast(EmbeddingModel, self._model)
|
||||
|
||||
def id(self) -> int:
|
||||
"""
|
||||
|
|
@ -41,6 +50,8 @@ class Collection:
|
|||
try:
|
||||
row = next(rows)
|
||||
self._id = row["id"]
|
||||
if self._model_id is None:
|
||||
self._model_id = row["model"]
|
||||
except StopIteration:
|
||||
# Create it
|
||||
self._id = (
|
||||
|
|
@ -48,7 +59,7 @@ class Collection:
|
|||
.insert(
|
||||
{
|
||||
"name": self.name,
|
||||
"model": cast(EmbeddingModel, self.model).model_id,
|
||||
"model": self.model().model_id,
|
||||
}
|
||||
)
|
||||
.last_pk
|
||||
|
|
@ -103,7 +114,7 @@ class Collection:
|
|||
"""
|
||||
from llm import encode
|
||||
|
||||
embedding = cast(EmbeddingModel, self.model).embed(text)
|
||||
embedding = self.model().embed(text)
|
||||
cast(Table, self.db["embeddings"]).insert(
|
||||
{
|
||||
"collection_id": self.id(),
|
||||
|
|
@ -136,6 +147,50 @@ class Collection:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def similar_by_vector(
|
||||
self, vector: List[float], number: int = 5, skip_id: Optional[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
Find similar items in the collection by a given vector.
|
||||
|
||||
Args:
|
||||
vector (list): Vector to search by
|
||||
number (int, optional): Number of similar items to return
|
||||
|
||||
Returns:
|
||||
list: List of (id, score) tuples
|
||||
"""
|
||||
import llm
|
||||
|
||||
def distance_score(other_encoded):
|
||||
other_vector = llm.decode(other_encoded)
|
||||
return llm.cosine_similarity(other_vector, vector)
|
||||
|
||||
self.db.register_function(distance_score, replace=True)
|
||||
|
||||
where_bits = ["collection_id = ?"]
|
||||
where_args = [str(self.id())]
|
||||
|
||||
if skip_id:
|
||||
where_bits.append("id != ?")
|
||||
where_args.append(skip_id)
|
||||
|
||||
return [
|
||||
(row["id"], row["score"])
|
||||
for row in self.db.query(
|
||||
"""
|
||||
select id, distance_score(embedding) as score
|
||||
from embeddings
|
||||
where {where}
|
||||
order by score desc limit {number}
|
||||
""".format(
|
||||
where=" and ".join(where_bits),
|
||||
number=number,
|
||||
),
|
||||
where_args,
|
||||
)
|
||||
]
|
||||
|
||||
def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
Find similar items in the collection by a given ID.
|
||||
|
|
@ -147,7 +202,18 @@ class Collection:
|
|||
Returns:
|
||||
list: List of (id, score) tuples
|
||||
"""
|
||||
raise NotImplementedError
|
||||
import llm
|
||||
|
||||
matches = list(
|
||||
self.db["embeddings"].rows_where(
|
||||
"collection_id = ? and id = ?", (self.id(), id)
|
||||
)
|
||||
)
|
||||
if not matches:
|
||||
raise ValueError("ID not found")
|
||||
embedding = matches[0]["embedding"]
|
||||
comparison_vector = llm.decode(embedding)
|
||||
return self.similar_by_vector(comparison_vector, number, skip_id=id)
|
||||
|
||||
def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
|
|
@ -160,4 +226,5 @@ class Collection:
|
|||
Returns:
|
||||
list: List of (id, score) tuples
|
||||
"""
|
||||
raise NotImplementedError
|
||||
comparison_vector = self.model().embed(text)
|
||||
return self.similar_by_vector(comparison_vector, number)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,15 @@
|
|||
import llm
|
||||
import sqlite_utils
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def collection():
|
||||
db = sqlite_utils.Database(memory=True)
|
||||
collection = llm.Collection(db, "test", model_id="embed-demo")
|
||||
collection.embed(1, "hello world")
|
||||
collection.embed(2, "goodbye world")
|
||||
return collection
|
||||
|
||||
|
||||
def test_demo_plugin():
|
||||
|
|
@ -21,17 +31,11 @@ def test_embed_huge_list():
|
|||
assert model.batch_count == 100
|
||||
|
||||
|
||||
def test_collection():
|
||||
db = sqlite_utils.Database(memory=True)
|
||||
collection = llm.Collection(db, "test", model_id="embed-demo")
|
||||
def test_collection(collection):
|
||||
assert collection.id() == 1
|
||||
assert collection.count() == 0
|
||||
# Embed some stuff
|
||||
collection.embed(1, "hello world")
|
||||
collection.embed(2, "goodbye world")
|
||||
assert collection.count() == 2
|
||||
# Check that the embeddings are there
|
||||
rows = list(db["embeddings"].rows)
|
||||
rows = list(collection.db["embeddings"].rows)
|
||||
assert rows == [
|
||||
{
|
||||
"collection_id": 1,
|
||||
|
|
@ -48,3 +52,18 @@ def test_collection():
|
|||
"metadata": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_similar(collection):
|
||||
results = list(collection.similar("hello world"))
|
||||
assert results == [
|
||||
{"id": "1", "score": pytest.approx(0.9999999999999999)},
|
||||
{"id": "2", "score": pytest.approx(0.9863939238321437)},
|
||||
]
|
||||
|
||||
|
||||
def test_similar_by_id(collection):
|
||||
results = list(collection.similar_by_id("1"))
|
||||
assert results == [
|
||||
{"id": "2", "score": pytest.approx(0.9863939238321437)},
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue