mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-27 08:24:47 +00:00
Initial Collection class plus test, refs #191
This commit is contained in:
parent
c25e7c4713
commit
6f761702dc
4 changed files with 205 additions and 9 deletions
|
|
@ -13,6 +13,7 @@ from .models import (
|
|||
Prompt,
|
||||
Response,
|
||||
)
|
||||
from .embeddings import Collection
|
||||
from .templates import Template
|
||||
from .plugins import pm
|
||||
import click
|
||||
|
|
@ -20,12 +21,14 @@ from typing import Dict, List, Optional
|
|||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import struct
|
||||
|
||||
__all__ = [
|
||||
"hookimpl",
|
||||
"get_model",
|
||||
"get_key",
|
||||
"user_dir",
|
||||
"Collection",
|
||||
"Conversation",
|
||||
"Model",
|
||||
"Options",
|
||||
|
|
@ -226,3 +229,11 @@ def remove_alias(alias):
|
|||
raise KeyError("No such alias: {}".format(alias))
|
||||
del current[alias]
|
||||
path.write_text(json.dumps(current, indent=4) + "\n")
|
||||
|
||||
|
||||
def encode(values):
    """Pack a sequence of floats into little-endian 32-bit binary."""
    return struct.pack("<{}f".format(len(values)), *values)
|
||||
|
||||
|
||||
def decode(binary):
    """Unpack little-endian 32-bit binary into a tuple of floats."""
    count = len(binary) // 4
    return struct.unpack("<{}f".format(count), binary)
|
||||
|
|
|
|||
11
llm/cli.py
11
llm/cli.py
|
|
@ -6,6 +6,8 @@ from llm import (
|
|||
Response,
|
||||
Template,
|
||||
UnknownModelError,
|
||||
decode,
|
||||
encode,
|
||||
get_embedding_models_with_aliases,
|
||||
get_embedding_model,
|
||||
get_key,
|
||||
|
|
@ -28,7 +30,6 @@ from runpy import run_module
|
|||
import shutil
|
||||
import sqlite_utils
|
||||
from sqlite_utils.db import NotFoundError
|
||||
import struct
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Optional
|
||||
|
|
@ -1288,14 +1289,6 @@ def logs_on():
|
|||
return not (user_dir() / "logs-off").exists()
|
||||
|
||||
|
||||
def encode(values):
    """Serialize floats as consecutive little-endian 4-byte values."""
    fmt = "<" + str(len(values)) + "f"
    return struct.pack(fmt, *values)
|
||||
|
||||
|
||||
def decode(binary):
    """Deserialize consecutive little-endian 4-byte floats to a tuple."""
    fmt = "<" + str(len(binary) // 4) + "f"
    return struct.unpack(fmt, binary)
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
dot_product = sum(x * y for x, y in zip(a, b))
|
||||
magnitude_a = sum(x * x for x in a) ** 0.5
|
||||
|
|
|
|||
162
llm/embeddings.py
Normal file
162
llm/embeddings.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
from .models import EmbeddingModel
|
||||
from .embeddings_migrations import embeddings_migrations
|
||||
import json
|
||||
from sqlite_utils import Database
|
||||
from typing import Any, Dict, List, Tuple, Optional, Union
|
||||
|
||||
|
||||
class Collection:
    """
    A named collection of embeddings stored in a SQLite database.

    The collection itself is a row in the ``collections`` table; its
    embedded items are rows in the ``embeddings`` table keyed by
    ``collection_id``. Tables are created on demand via
    ``embeddings_migrations``.
    """

    def __init__(
        self,
        db: Database,
        name: str,
        *,
        model: Optional[EmbeddingModel] = None,
        model_id: Optional[str] = None,
    ) -> None:
        """
        Args:
            db: sqlite_utils Database holding the collection
            name: Name of the collection
            model: Embedding model to use (optional)
            model_id: ID of an embedding model to resolve (optional)

        Raises:
            ValueError: if both model and model_id are given but disagree
        """
        # Local import avoids a circular import at module load time
        from llm import get_embedding_model

        self.db = db
        self.name = name
        if model and model_id and model.model_id != model_id:
            raise ValueError("model_id does not match model.model_id")
        if model_id and not model:
            model = get_embedding_model(model_id)
        self.model = model
        # Lazily resolved by id(); stays None until the DB row is read/created
        self._id = None

    def id(self) -> int:
        """
        Get the ID of the collection, creating it in the DB if necessary.

        Returns:
            int: ID of the collection

        Raises:
            ValueError: if the collection row must be created but no
                model was configured (previously surfaced as an opaque
                AttributeError on ``None.model_id``)
        """
        if self._id is not None:
            return self._id
        if not self.db["collections"].exists():
            embeddings_migrations.apply(self.db)
        rows = self.db["collections"].rows_where("name = ?", [self.name])
        try:
            row = next(rows)
            self._id = row["id"]
        except StopIteration:
            # Creating the row records which model produced its embeddings,
            # so a model is required at this point.
            if self.model is None:
                raise ValueError(
                    "No embedding model configured for collection: {}".format(
                        self.name
                    )
                )
            self._id = (
                self.db["collections"]
                .insert(
                    {
                        "name": self.name,
                        "model": self.model.model_id,
                    }
                )
                .last_pk
            )
        return self._id

    def exists(self) -> bool:
        """
        Check if the collection exists in the DB.

        Returns:
            bool: True if exists, False otherwise
        """
        # On a fresh database the collections table may not exist yet;
        # querying it directly would raise OperationalError.
        if not self.db["collections"].exists():
            return False
        matches = list(
            self.db.query("select 1 from collections where name = ?", (self.name,))
        )
        return bool(matches)

    def count(self) -> int:
        """
        Count the number of items in the collection.

        Returns:
            int: Number of items in the collection (0 if it has no table yet)
        """
        # Query by name rather than self.id() — id() would create the
        # collection row as a side effect, and a missing table means empty.
        if not self.db["embeddings"].exists():
            return 0
        return next(
            self.db.query(
                """
            select count(*) as c from embeddings where collection_id = (
                select id from collections where name = ?
            )
            """,
                (self.name,),
            )
        )["c"]

    def embed(
        self,
        id: str,
        text: str,
        metadata: Optional[Dict[str, Any]] = None,
        store: bool = False,
    ) -> None:
        """
        Embed a text and store it in the collection with a given ID.

        Args:
            id (str): ID for the text
            text (str): Text to be embedded
            metadata (dict, optional): Metadata to be stored
            store (bool, optional): Whether to store the text in the content column
        """
        # Local import avoids a circular import with the llm package
        from llm import encode

        embedding = self.model.embed(text)
        self.db["embeddings"].insert(
            {
                "collection_id": self.id(),
                "id": id,
                "embedding": encode(embedding),
                "content": text if store else None,
                # NOTE(review): falsy metadata (e.g. {}) is stored as NULL
                "metadata": json.dumps(metadata) if metadata else None,
            }
        )

    def embed_multi(self, id_text_map: Dict[str, str], store: bool = False) -> None:
        """
        Embed multiple texts and store them in the collection with given IDs.

        Args:
            id_text_map (dict): Dictionary mapping IDs to texts
            store (bool, optional): Whether to store the text in the content column
        """
        raise NotImplementedError

    def embed_multi_with_metadata(
        self,
        id_text_metadata_map: Dict[str, Tuple[str, Dict[str, Union[str, int, float]]]],
    ) -> None:
        """
        Embed multiple texts along with metadata and store them in the collection with given IDs.

        Args:
            id_text_metadata_map (dict): Dictionary mapping IDs to (text, metadata) tuples
        """
        raise NotImplementedError

    def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
        """
        Find similar items in the collection by a given ID.

        Args:
            id (str): ID to search by
            number (int, optional): Number of similar items to return

        Returns:
            list: List of (id, score) tuples
        """
        raise NotImplementedError

    def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
        """
        Find similar items in the collection by a given text.

        Args:
            text (str): Text to search by
            number (int, optional): Number of similar items to return

        Returns:
            list: List of (id, score) tuples
        """
        raise NotImplementedError
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import llm
|
||||
import sqlite_utils
|
||||
|
||||
|
||||
def test_demo_plugin():
|
||||
|
|
@ -18,3 +19,32 @@ def test_embed_huge_list():
|
|||
assert first_twos == {(5, 1): 10, (5, 2): 90, (5, 3): 900}
|
||||
# Should have happened in 100 batches
|
||||
assert model.batch_count == 100
|
||||
|
||||
|
||||
def test_collection():
    """End-to-end check of Collection against the embed-demo plugin model."""
    db = sqlite_utils.Database(memory=True)
    collection = llm.Collection(db, "test", model_id="embed-demo")
    # First collection created in a fresh DB gets rowid 1
    assert collection.id() == 1
    assert collection.count() == 0
    # Embed some stuff
    # NOTE(review): embed() annotates id as str but is called with ints here;
    # the rows below assert string ids "1"/"2" — presumably coerced on
    # insert. Confirm this is intentional or pass strings explicitly.
    collection.embed(1, "hello world")
    collection.embed(2, "goodbye world")
    assert collection.count() == 2
    # Check that the embeddings are there
    # Expected vectors assume embed-demo's 16-slot output — TODO confirm
    # against the plugin; content is None because store defaults to False.
    rows = list(db["embeddings"].rows)
    assert rows == [
        {
            "collection_id": 1,
            "id": "1",
            "embedding": llm.encode([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            "content": None,
            "metadata": None,
        },
        {
            "collection_id": 1,
            "id": "2",
            "embedding": llm.encode([7, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            "content": None,
            "metadata": None,
        },
    ]
|
||||
|
|
|
|||
Loading…
Reference in a new issue