From 1d5d73481a10a85b3f52651285024832bd2b9760 Mon Sep 17 00:00:00 2001 From: Dan Turkel Date: Sat, 24 May 2025 01:54:55 -0400 Subject: [PATCH] Filter IDs with `--prefix` in `llm similar` (#1052) --- docs/embeddings/cli.md | 8 +++++++- docs/help.md | 1 + llm/cli.py | 7 ++++--- llm/embeddings.py | 28 +++++++++++++++++++++++----- tests/test_embed.py | 7 +++++++ tests/test_embed_cli.py | 36 ++++++++++++++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 9 deletions(-) diff --git a/docs/embeddings/cli.md b/docs/embeddings/cli.md index 1228ec6..a6da55e 100644 --- a/docs/embeddings/cli.md +++ b/docs/embeddings/cli.md @@ -347,7 +347,7 @@ This embeds the provided string and returns a newline-delimited list of JSON obj ``` Use `-p/--plain` to get back results in plain text instead of JSON: ```bash -llm -similar quotations -c 'computer science' -p +llm similar quotations -c 'computer science' -p ``` Example output: ``` @@ -366,6 +366,12 @@ When using a model like CLIP, you can find images similar to an input image usin llm similar photos -i image.jpg --binary ``` +You can filter results to only show IDs that begin with a specific prefix using --prefix: + +```bash +llm similar quotations --prefix 'movies/' -c 'star wars' +``` + (embeddings-cli-embed-models)= ## llm embed-models diff --git a/docs/help.md b/docs/help.md index 8478b6b..e67deda 100644 --- a/docs/help.md +++ b/docs/help.md @@ -953,6 +953,7 @@ Options: -n, --number INTEGER Number of results to return -p, --plain Output in plain text format -d, --database FILE + --prefix TEXT Just IDs with this prefix --help Show this message and exit. 
``` diff --git a/llm/cli.py b/llm/cli.py index 9b31660..c867639 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -3137,7 +3137,8 @@ def embed_multi( type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True), envvar="LLM_EMBEDDINGS_DB", ) -def similar(collection, id, input, content, binary, number, plain, database): +@click.option("--prefix", help="Just IDs with this prefix", default="") +def similar(collection, id, input, content, binary, number, plain, database, prefix): """ Return top N similar IDs from a collection using cosine similarity. @@ -3169,7 +3170,7 @@ def similar(collection, id, input, content, binary, number, plain, database): if id: try: - results = collection_obj.similar_by_id(id, number) + results = collection_obj.similar_by_id(id, number, prefix=prefix) except Collection.DoesNotExist: raise click.ClickException("ID not found in collection") else: @@ -3185,7 +3186,7 @@ def similar(collection, id, input, content, binary, number, plain, database): content = f.read() if not content: raise click.ClickException("No content provided") - results = collection_obj.similar(content, number) + results = collection_obj.similar(content, number, prefix=prefix) for result in results: if plain: diff --git a/llm/embeddings.py b/llm/embeddings.py index 5efeda0..5c9bf8f 100644 --- a/llm/embeddings.py +++ b/llm/embeddings.py @@ -238,7 +238,11 @@ class Collection: ) def similar_by_vector( - self, vector: List[float], number: int = 10, skip_id: Optional[str] = None + self, + vector: List[float], + number: int = 10, + skip_id: Optional[str] = None, + prefix: Optional[str] = None, ) -> List[Entry]: """ Find similar items in the collection by a given vector. 
@@ -246,6 +250,8 @@ class Collection: Args: vector (list): Vector to search by number (int, optional): Number of similar items to return + skip_id (str, optional): An ID to exclude from the results + prefix: (str, optional): Filter results to IDs with this prefix Returns: list: List of Entry objects @@ -261,6 +267,10 @@ class Collection: where_bits = ["collection_id = ?"] where_args = [str(self.id)] + if prefix: + where_bits.append("id LIKE ? || '%'") + where_args.append(prefix) + if skip_id: where_bits.append("id != ?") where_args.append(skip_id) @@ -286,13 +296,16 @@ class Collection: ) ] - def similar_by_id(self, id: str, number: int = 10) -> List[Entry]: + def similar_by_id( + self, id: str, number: int = 10, prefix: Optional[str] = None + ) -> List[Entry]: """ Find similar items in the collection by a given ID. Args: id (str): ID to search by number (int, optional): Number of similar items to return + prefix: (str, optional): Filter results to IDs with this prefix Returns: list: List of Entry objects @@ -308,21 +321,26 @@ class Collection: raise self.DoesNotExist("ID not found") embedding = matches[0]["embedding"] comparison_vector = llm.decode(embedding) - return self.similar_by_vector(comparison_vector, number, skip_id=id) + return self.similar_by_vector( + comparison_vector, number, skip_id=id, prefix=prefix + ) - def similar(self, value: Union[str, bytes], number: int = 10) -> List[Entry]: + def similar( + self, value: Union[str, bytes], number: int = 10, prefix: Optional[str] = None + ) -> List[Entry]: """ Find similar items in the collection by a given value. 
Args: value (str or bytes): value to search by number (int, optional): Number of similar items to return + prefix: (str, optional): Filter results to IDs with this prefix Returns: list: List of Entry objects """ comparison_vector = self.model().embed(value) - return self.similar_by_vector(comparison_vector, number) + return self.similar_by_vector(comparison_vector, number, prefix=prefix) @classmethod def exists(cls, db: Database, name: str) -> bool: diff --git a/tests/test_embed.py b/tests/test_embed.py index 50b93c6..9b9c809 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -93,6 +93,13 @@ def test_similar(collection): ] +def test_similar_prefixed(collection): + results = list(collection.similar("hello world", prefix="2")) + assert results == [ + Entry(id="2", score=pytest.approx(0.9863939238321437)), + ] + + def test_similar_by_id(collection): results = list(collection.similar_by_id("1")) assert results == [ diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py index d993239..afee771 100644 --- a/tests/test_embed_cli.py +++ b/tests/test_embed_cli.py @@ -260,6 +260,42 @@ def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario): } +@pytest.mark.parametrize( + "prefix,expected_result", + ( + ( + 1, + { + "id": "1", + "score": pytest.approx(0.7071067811865475), + "content": "hello world", + "metadata": None, + }, + ), + ( + 2, + { + "id": "2", + "score": pytest.approx(0.8137334712067349), + "content": "goodbye world", + "metadata": None, + }, + ), + ), +) +def test_similar_by_content_prefixed( + user_path_with_embeddings, prefix, expected_result +): + runner = CliRunner() + result = runner.invoke( + cli, + ["similar", "demo", "-c", "world", "--prefix", prefix, "-n", "1"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert json.loads(result.output) == expected_result + + @pytest.mark.parametrize("use_stdin", (False, True)) @pytest.mark.parametrize("prefix", (None, "prefix")) 
@pytest.mark.parametrize("prepend", (None, "search_document: "))