def cosine_similarity(a, b):
    """
    Return the cosine similarity between two numeric vectors.

    Args:
        a: Sequence of numbers (iterated twice, so not a one-shot iterator)
        b: Sequence of numbers, paired element-wise with ``a``

    Returns:
        float: dot(a, b) / (|a| * |b|), or 0.0 when either vector has zero
        magnitude (including empty input), where the ratio is undefined.

    If the lengths differ, ``zip`` pairs only the common prefix for the dot
    product, while each magnitude is still computed over the full vector.
    """
    dot_product = sum(x * y for x, y in zip(a, b))
    magnitude_a = sum(x * x for x in a) ** 0.5
    magnitude_b = sum(x * x for x in b) ** 0.5
    denominator = magnitude_a * magnitude_b
    if not denominator:
        # A zero vector has no direction; report "no similarity" instead of
        # raising ZeroDivisionError from inside a scoring loop.
        return 0.0
    return dot_product / denominator
and id = ?", (collection_row["id"], id) - ) - ) - if not matches: - raise click.ClickException("No match found for id: {}".format(id)) - embedding = matches[0]["embedding"] - comparison_vector = decode(embedding) + results = collection_obj.similar_by_id(id, number) else: - # Embed the content that was provided instead if not content: if not input: # Read from stdin @@ -1053,38 +1042,7 @@ def similar(collection, id, input, content, number, database): content = input.read() if not content: raise click.ClickException("No content provided") - model = collection_row["model"] - try: - model = get_embedding_model(model) - except UnknownModelError as ex: - raise click.ClickException(str(ex)) - comparison_vector = model.embed(content) - - def distance_score(other_encoded): - other_vector = decode(other_encoded) - return cosine_similarity(other_vector, comparison_vector) - - db.register_function(distance_score) - - where_bits = ["collection_id = ?"] - where_args = [collection_row["id"]] - - if id: - where_bits.append("id != ?") - where_args.append(id) - - results = db.query( - """ - select id, distance_score(embedding) as score - from embeddings - where {where} - order by score desc limit {number} - """.format( - where=" and ".join(where_bits), - number=number, - ), - where_args, - ) + results = collection_obj.similar_by_content(content, number) for result in results: click.echo(json.dumps(result)) @@ -1287,10 +1245,3 @@ def _human_readable_size(size_bytes): def logs_on(): return not (user_dir() / "logs-off").exists() - - -def cosine_similarity(a, b): - dot_product = sum(x * y for x, y in zip(a, b)) - magnitude_a = sum(x * x for x in a) ** 0.5 - magnitude_b = sum(x * x for x in b) ** 0.5 - return dot_product / (magnitude_a * magnitude_b) diff --git a/llm/embeddings.py b/llm/embeddings.py index ee41afd..0e14f99 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -15,16 +15,25 @@ class Collection: model: Optional[EmbeddingModel] = None, model_id: Optional[str] = None, ) 
def similar_by_vector(
    self, vector: List[float], number: int = 5, skip_id: Optional[str] = None
) -> List[Tuple[str, float]]:
    """
    Find similar items in the collection by a given vector.

    Args:
        vector (list): Vector to search by
        number (int, optional): Number of similar items to return
        skip_id (str, optional): ID of an item to leave out of the results

    Returns:
        list: List of (id, score) tuples, highest score first
    """
    import llm

    def distance_score(other_encoded):
        other_vector = llm.decode(other_encoded)
        return llm.cosine_similarity(other_vector, vector)

    # Each call closes over a different ``vector``, so any function
    # registered by an earlier call has to be replaced.
    self.db.register_function(distance_score, replace=True)

    where_bits = ["collection_id = ?"]
    where_args: list = [self.id()]

    # Compare against None so that any ID that was actually passed is
    # excluded, even a falsy one.
    if skip_id is not None:
        where_bits.append("id != ?")
        where_args.append(skip_id)

    # Only the fixed WHERE fragments are formatted into the SQL text; the
    # limit is bound as a parameter like every other caller-supplied value.
    sql = """
        select id, distance_score(embedding) as score
        from embeddings
        where {where}
        order by score desc limit ?
    """.format(
        where=" and ".join(where_bits)
    )
    return [
        (row["id"], row["score"])
        for row in self.db.query(sql, where_args + [int(number)])
    ]
def test_similar(collection):
    # similar() returns (id, score) tuples ordered by descending score; the
    # query text is the same text stored under ID "1", so that item is first.
    results = list(collection.similar("hello world"))
    assert results == [
        ("1", pytest.approx(0.9999999999999999)),
        ("2", pytest.approx(0.9863939238321437)),
    ]


def test_similar_by_id(collection):
    # The item being compared against is passed as skip_id, so only the
    # other stored item should come back.
    results = list(collection.similar_by_id("1"))
    assert results == [
        ("2", pytest.approx(0.9863939238321437)),
    ]