Filter IDs with --prefix in llm similar (#1052)

This commit is contained in:
Dan Turkel 2025-05-24 01:54:55 -04:00 committed by GitHub
parent d5f7bf9073
commit 1d5d73481a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 78 additions and 9 deletions

View file

@ -347,7 +347,7 @@ This embeds the provided string and returns a newline-delimited list of JSON obj
```
Use `-p/--plain` to get back results in plain text instead of JSON:
```bash
llm -similar quotations -c 'computer science' -p
llm similar quotations -c 'computer science' -p
```
Example output:
```
@ -366,6 +366,12 @@ When using a model like CLIP, you can find images similar to an input image usin
llm similar photos -i image.jpg --binary
```
You can filter results to only show IDs that begin with a specific prefix using `--prefix`:
```bash
llm similar quotations --prefix 'movies/' -c 'star wars'
```
(embeddings-cli-embed-models)=
## llm embed-models

View file

@ -953,6 +953,7 @@ Options:
-n, --number INTEGER Number of results to return
-p, --plain Output in plain text format
-d, --database FILE
--prefix TEXT Just IDs with this prefix
--help Show this message and exit.
```

View file

@ -3137,7 +3137,8 @@ def embed_multi(
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
)
def similar(collection, id, input, content, binary, number, plain, database):
@click.option("--prefix", help="Just IDs with this prefix", default="")
def similar(collection, id, input, content, binary, number, plain, database, prefix):
"""
Return top N similar IDs from a collection using cosine similarity.
@ -3169,7 +3170,7 @@ def similar(collection, id, input, content, binary, number, plain, database):
if id:
try:
results = collection_obj.similar_by_id(id, number)
results = collection_obj.similar_by_id(id, number, prefix=prefix)
except Collection.DoesNotExist:
raise click.ClickException("ID not found in collection")
else:
@ -3185,7 +3186,7 @@ def similar(collection, id, input, content, binary, number, plain, database):
content = f.read()
if not content:
raise click.ClickException("No content provided")
results = collection_obj.similar(content, number)
results = collection_obj.similar(content, number, prefix=prefix)
for result in results:
if plain:

View file

@ -238,7 +238,11 @@ class Collection:
)
def similar_by_vector(
self, vector: List[float], number: int = 10, skip_id: Optional[str] = None
self,
vector: List[float],
number: int = 10,
skip_id: Optional[str] = None,
prefix: Optional[str] = None,
) -> List[Entry]:
"""
Find similar items in the collection by a given vector.
@ -246,6 +250,8 @@ class Collection:
Args:
vector (list): Vector to search by
number (int, optional): Number of similar items to return
skip_id (str, optional): An ID to exclude from the results
prefix (str, optional): Filter results to IDs with this prefix
Returns:
list: List of Entry objects
@ -261,6 +267,10 @@ class Collection:
where_bits = ["collection_id = ?"]
where_args = [str(self.id)]
if prefix:
where_bits.append("id LIKE ? || '%'")
where_args.append(prefix)
if skip_id:
where_bits.append("id != ?")
where_args.append(skip_id)
@ -286,13 +296,16 @@ class Collection:
)
]
def similar_by_id(self, id: str, number: int = 10) -> List[Entry]:
def similar_by_id(
self, id: str, number: int = 10, prefix: Optional[str] = None
) -> List[Entry]:
"""
Find similar items in the collection by a given ID.
Args:
id (str): ID to search by
number (int, optional): Number of similar items to return
prefix (str, optional): Filter results to IDs with this prefix
Returns:
list: List of Entry objects
@ -308,21 +321,26 @@ class Collection:
raise self.DoesNotExist("ID not found")
embedding = matches[0]["embedding"]
comparison_vector = llm.decode(embedding)
return self.similar_by_vector(comparison_vector, number, skip_id=id)
return self.similar_by_vector(
comparison_vector, number, skip_id=id, prefix=prefix
)
def similar(self, value: Union[str, bytes], number: int = 10) -> List[Entry]:
def similar(
self, value: Union[str, bytes], number: int = 10, prefix: Optional[str] = None
) -> List[Entry]:
"""
Find similar items in the collection by a given value.
Args:
value (str or bytes): value to search by
number (int, optional): Number of similar items to return
prefix (str, optional): Filter results to IDs with this prefix
Returns:
list: List of Entry objects
"""
comparison_vector = self.model().embed(value)
return self.similar_by_vector(comparison_vector, number)
return self.similar_by_vector(comparison_vector, number, prefix=prefix)
@classmethod
def exists(cls, db: Database, name: str) -> bool:

View file

@ -93,6 +93,13 @@ def test_similar(collection):
]
def test_similar_prefixed(collection):
    """Searching with prefix="2" yields only the entry whose ID is "2"."""
    expected = [Entry(id="2", score=pytest.approx(0.9863939238321437))]
    matches = list(collection.similar("hello world", prefix="2"))
    assert matches == expected
def test_similar_by_id(collection):
results = list(collection.similar_by_id("1"))
assert results == [

View file

@ -260,6 +260,42 @@ def test_similar_by_content_cli(tmpdir, user_path_with_embeddings, scenario):
}
@pytest.mark.parametrize(
    "prefix,expected_result",
    (
        (
            # Prefixes are passed as strings: they end up in an argv-style
            # list, and command-line arguments are always strings.
            "1",
            {
                "id": "1",
                "score": pytest.approx(0.7071067811865475),
                "content": "hello world",
                "metadata": None,
            },
        ),
        (
            "2",
            {
                "id": "2",
                "score": pytest.approx(0.8137334712067349),
                "content": "goodbye world",
                "metadata": None,
            },
        ),
    ),
)
def test_similar_by_content_prefixed(
    user_path_with_embeddings, prefix, expected_result
):
    """`llm similar --prefix` restricts results to IDs starting with the prefix.

    Runs `llm similar demo -c world --prefix <prefix> -n 1` and checks that
    the single JSON result printed is the expected entry for that prefix.
    """
    runner = CliRunner()
    result = runner.invoke(
        cli,
        ["similar", "demo", "-c", "world", "--prefix", prefix, "-n", "1"],
        catch_exceptions=False,
    )
    assert result.exit_code == 0
    # With -n 1 the output is a single JSON object, so it parses in one go.
    assert json.loads(result.output) == expected_result
@pytest.mark.parametrize("use_stdin", (False, True))
@pytest.mark.parametrize("prefix", (None, "prefix"))
@pytest.mark.parametrize("prepend", (None, "search_document: "))