Initial CLI support and plugin hook for embeddings, refs #185

* Embeddings plugin hook + OpenAI implementation
* llm.get_embedding_model(name) function
* llm embed command, for returning embeddings or saving them to SQLite
* Tests using an EmbedDemo embedding model
* llm embed-models list and emeb-models default commands
* llm embed-db path and llm embed-db collections commands
This commit is contained in:
Simon Willison 2023-08-27 22:24:10 -07:00 committed by GitHub
parent cee5b06604
commit 77cf56e54a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 825 additions and 28 deletions

21
docs/embeddings/binary.md Normal file
View file

@ -0,0 +1,21 @@
(embeddings-binary)=
# Binary embedding formats
The default output format of the `llm embed` command is a JSON array of floating point numbers.
LLM stores embeddings in a more space-efficient format: little-endian binary sequences of 32-bit floating point numbers, each represented using 4 bytes.
The following Python functions can be used to convert between the two formats:
```python
import struct
def encode(values):
return struct.pack("<" + "f" * len(values), *values)
def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)
```
When using `llm embed` directly, the default output format is JSON.
Use `--format blob` for the binary output, `--format hex` for that binary output as hexadecimal and `--format base64` for that binary output encoded using base64.

97
docs/embeddings/cli.md Normal file
View file

@ -0,0 +1,97 @@
(embeddings-cli)=
# Embedding with the CLI
LLM provides command-line utilities for calculating and storing embeddings for pieces of content.
(embeddings-llm-embed)=
## llm embed
The `llm embed` command can be used to calculate embedding vectors for a string of content. These can be returned directly to the terminal, stored in a SQLite database, or both.
### Returning embeddings to the terminal
The simplest way to use this command is to pass content to it using the `-c/--content` option, like this:
```bash
llm embed -c 'This is some content'
```
The command will return a JSON array of floating point numbers directly to the terminal:
```json
[0.123, 0.456, 0.789...]
```
By default it uses the {ref}`default embedding model <embeddings-cli-embed-models-default>`.
Use the `-m/--model` option to specify a different model:
```bash
llm -m sentence-transformers/all-MiniLM-L6-v2 \
-c 'This is some content'
```
See {ref}`embeddings-binary` for options to get back embeddings in formats other than JSON.
### Storing embeddings in SQLite
Embeddings are much more useful if you store them somewhere, so you can calculate similarity scores between different embeddings later on.
LLM includes a concept of a "collection" of embeddings. This is a named object where multiple pieces of content can be stored, each with a unique ID.
The `llm embed` command can store results directly in a named collection like this:
```bash
cat one.txt | llm embed my-files one
```
This will store the embedding for the contents of `one.txt` in the `my-files` collection under the key `one`.
A collection will be created the first time you mention it.
Collections have a fixed embedding model, which is the model that was used for the first embedding stored in that collection.
In the above example this would have been the default embedding model at the time that the command was run.
This example stores the embedding of the string "my happy hound" in a collection called `phrases` under the key `hound` and using the model `ada-002`:
```bash
llm embed -m ada-002 -c 'my happy hound' phrases hound
```
By default, the SQLite database used to store embeddings is the `embeddings.db` in the user content directory managed by LLM.
You can see the path to this directory by running `llm embed-db path`.
You can store embeddings in a different SQLite database by passing a path to it using the `-d/--database` option to `llm embed`. If this file does not exist yet the command will create it:
```bash
llm embed -d my-embeddings.db -c 'my happy hound' phrases hound
```
This creates a database file called `my-embeddings.db` in the current directory.
(embeddings-cli-embed-models-default)=
## llm embed-models default
This command can be used to get and set the default embedding model.
This will return the name of the current default model:
```bash
llm embed-models default
```
You can set a different default like this:
```
llm embed-models default name-of-other-model
```
Any of the supported aliases for a model can be passed to this command.
## llm embed-db collections
To list all of the collections in the embeddings database, run this command:
```bash
llm embed-db collections
```
Add `--json` for JSON output:
```bash
llm embed-db collections --json
```
Add `-d/--database` to specify a different database file:
```bash
llm embed-db collections -d my-embeddings.db
```

21
docs/embeddings/index.md Normal file
View file

@ -0,0 +1,21 @@
(embeddings)=
# Embeddings
Embedding models allow you to take a piece of text - a word, sentence, paragraph or even a whole articles, and convert that into an array of floating point numbers.
This floating point array is called an "embedding vector", and works as a numerical representation of the semantic meaning of the content in a many-multi-dimensional space.
By calculating the distance between embedding vectors, we can identify which content is semantically "nearest" to other content.
This can be used to build features like related article lookups. It can also be used to build semantic search, where a user can search for a phrase and get back results that are semantically similar to that phrase even if they do not share any exact keywords.
LLM supports multiple embedding models through {ref}`plugins <plugins>`. Once installed, an embedding model can be used on the command-line or via the Python API to calculate and store embeddings for content, and then to perform similarity searches against those embeddings.
```{toctree}
---
maxdepth: 3
---
cli
writing-plugins
binary
```

View file

@ -0,0 +1,48 @@
(embeddings-writing-plugins)=
# Writing plugins to add new embedding models
Read the {ref}`plugin tutorial <tutorial-model-plugin>` for details on how to develop and package a plugin.
This page shows an example plugin that implements and registers a new embedding model.
There are two components to an embedding model plugin:
1. An implementation of the `register_embedding_models()` hook, which takes a `register` callback function and calls it to register the new model with the LLM plugin system.
2. A class that extends the `llm.EmbeddingModel` abstract base class.
The only required method on this class is `embed(text)`, which takes a string and returns a list of floating point numbers.
The following example uses the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) package to provide access to the [MiniLM-L6](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) embedding model.
```python
import llm
from sentence_transformers import SentenceTransformer
@llm.hookimpl
def register_embedding_models(register):
model_id = "sentence-transformers/all-MiniLM-L6-v2"
register(SentenceTransformerModel(model_id, model_id, 384), aliases=("all-MiniLM-L6-v2",))
class SentenceTransformerModel(llm.EmbeddingModel):
def __init__(self, model_id, model_name, embedding_size):
self.model_id = model_id
self.model_name = model_name
self.embedding_size = embedding_size
self._model = None
def embed(self, text):
if self._model is None:
self._model = SentenceTransformer(self.model_name)
return list(map(float, self._model.encode([text])[0]))
```
Once installed, the model provided by this plugin can be used with the {ref}`llm embed <embeddings-llm-embed>` command like this:
```bash
cat file.txt | llm embed -m sentence-transformers/all-MiniLM-L6-v2
```
Or via its registered alias like this:
```bash
cat file.txt | llm embed -m all-MiniLM-L6-v2
```

View file

@ -53,16 +53,19 @@ Options:
--help Show this message and exit.
Commands:
prompt* Execute a prompt
aliases Manage model aliases
install Install packages from PyPI into the same environment as LLM
keys Manage stored API keys for different models
logs Tools for exploring logged prompts and responses
models Manage available models
openai Commands for working directly with the OpenAI API
plugins List installed plugins
templates Manage stored prompt templates
uninstall Uninstall Python packages from the LLM environment
prompt* Execute a prompt
aliases Manage model aliases
embed Embed text and store or return the result
embed-db Manage the embeddings database
embed-models Manage available embedding models
install Install packages from PyPI into the same environment as LLM
keys Manage stored API keys for different models
logs Tools for exploring logged prompts and responses
models Manage available models
openai Commands for working directly with the OpenAI API
plugins List installed plugins
templates Manage stored prompt templates
uninstall Uninstall Python packages from the LLM environment
```
### llm prompt --help
```
@ -380,6 +383,86 @@ Options:
-y, --yes Don't ask for confirmation
--help Show this message and exit.
```
### llm embed --help
```
Usage: llm embed [OPTIONS] [COLLECTION] [ID]
Embed text and store or return the result
Options:
-i, --input FILE Content to embed
-m, --model TEXT Embedding model to use
--store Store the text itself in the database
-d, --database FILE
-c, --content FILE
-f, --format [json|blob|base64|hex]
Output format
--help Show this message and exit.
```
### llm embed-models --help
```
Usage: llm embed-models [OPTIONS] COMMAND [ARGS]...
Manage available embedding models
Options:
--help Show this message and exit.
Commands:
list* List available embedding models
default Show or set the default embedding model
```
#### llm embed-models list --help
```
Usage: llm embed-models list [OPTIONS]
List available embedding models
Options:
--help Show this message and exit.
```
#### llm embed-models default --help
```
Usage: llm embed-models default [OPTIONS] [MODEL]
Show or set the default embedding model
Options:
--help Show this message and exit.
```
### llm embed-db --help
```
Usage: llm embed-db [OPTIONS] COMMAND [ARGS]...
Manage the embeddings database
Options:
--help Show this message and exit.
Commands:
collections Output the path to the embeddings database
path Output the path to the embeddings database
```
#### llm embed-db path --help
```
Usage: llm embed-db path [OPTIONS]
Output the path to the embeddings database
Options:
--help Show this message and exit.
```
#### llm embed-db collections --help
```
Usage: llm embed-db collections [OPTIONS]
Output the path to the embeddings database
Options:
-d, --database FILE Path to embeddings database
--json Output as JSON
--help Show this message and exit.
```
### llm openai --help
```
Usage: llm openai [OPTIONS] COMMAND [ARGS]...

View file

@ -57,6 +57,7 @@ maxdepth: 3
setup
usage
other-models
embeddings/index
plugins/index
aliases
python-api

View file

@ -7,6 +7,8 @@ from .models import (
Conversation,
Model,
ModelWithAliases,
EmbeddingModel,
EmbeddingModelWithAliases,
Options,
Prompt,
Response,
@ -73,6 +75,55 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
return model_aliases
def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
model_aliases = []
# Include aliases from aliases.json
aliases_path = user_dir() / "aliases.json"
extra_model_aliases: Dict[str, list] = {}
if aliases_path.exists():
configured_aliases = json.loads(aliases_path.read_text())
for alias, model_id in configured_aliases.items():
extra_model_aliases.setdefault(model_id, []).append(alias)
def register(model, aliases=None):
alias_list = list(aliases or [])
if model.model_id in extra_model_aliases:
alias_list.extend(extra_model_aliases[model.model_id])
model_aliases.append(EmbeddingModelWithAliases(model, alias_list))
pm.hook.register_embedding_models(register=register)
return model_aliases
def get_embedding_models():
models = []
def register(model, aliases=None):
models.append(model)
pm.hook.register_embedding_models(register=register)
return models
def get_embedding_model(name):
aliases = get_embedding_model_aliases()
try:
return aliases[name]
except KeyError:
raise UnknownModelError("Unknown model: " + name)
def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
model_aliases = {}
for model_with_aliases in get_embedding_models_with_aliases():
for alias in model_with_aliases.aliases:
model_aliases[alias] = model_with_aliases.model
model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
return model_aliases
def get_model_aliases() -> Dict[str, Model]:
model_aliases = {}
for model_with_aliases in get_models_with_aliases():

View file

@ -6,6 +6,8 @@ from llm import (
Response,
Template,
UnknownModelError,
get_embedding_models_with_aliases,
get_embedding_model,
get_key,
get_plugins,
get_model,
@ -17,12 +19,16 @@ from llm import (
)
from .migrations import migrate
from .embeddings_migrations import embeddings_migrations
from .plugins import pm
import base64
import pathlib
import pydantic
from runpy import run_module
import shutil
import sqlite_utils
from sqlite_utils.db import NotFoundError
import struct
import sys
import textwrap
from typing import cast, Optional
@ -32,6 +38,7 @@ import yaml
warnings.simplefilter("ignore", ResourceWarning)
DEFAULT_MODEL = "gpt-3.5-turbo"
DEFAULT_EMBEDDING_MODEL = "ada-002"
DEFAULT_TEMPLATE = "prompt: "
@ -853,6 +860,225 @@ def uninstall(packages, yes):
run_module("pip", run_name="__main__")
@cli.command()
@click.argument("collection", required=False)
@click.argument("id", required=False)
@click.option(
"-i",
"--input",
type=click.Path(file_okay=True, allow_dash=True, dir_okay=False),
help="Content to embed",
)
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
)
@click.option(
"-c",
"--content",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
)
@click.option(
"format_",
"-f",
"--format",
type=click.Choice(["json", "blob", "base64", "hex"]),
help="Output format",
)
def embed(collection, id, input, model, store, database, content, format_):
"""Embed text and store or return the result"""
if collection and not id:
raise click.ClickException("Must provide both collection and id")
db = None
def get_db():
if database:
return sqlite_utils.Database(database)
else:
return sqlite_utils.Database(user_dir() / "embeddings.db")
existing_collection = None
if collection:
db = get_db()
if db["collections"].exists():
try:
existing_collection = get_collection(db, collection)
except NotFoundError:
pass
if model is None:
# If collection exists, use that model
if existing_collection:
model = existing_collection["model"]
else:
# Use default model
model = get_default_embedding_model()
if model and existing_collection and model != existing_collection["model"]:
raise click.ClickException(
"Model '{}' does not match '{}' collection model of '{}'".format(
model, collection, existing_collection["model"]
)
)
try:
model = get_embedding_model(model)
except UnknownModelError as ex:
raise click.ClickException(str(ex))
show_output = True
if collection and (format_ is None):
show_output = False
# Resolve input text
if not content:
if not input:
# Read from stdin
input = sys.stdin
content = input.read()
if not content:
raise click.ClickException("No content provided")
embedding = model.embed(content)
if collection:
# Store the embedding
if db is None:
db = get_db()
embeddings_migrations.apply(db)
if not existing_collection:
db["collections"].insert(
{
"name": collection,
"model": model.model_id,
}
)
existing_collection = get_collection(db, collection)
# Now store it
db["embeddings"].insert(
{
"collection_id": existing_collection["id"],
"id": id,
"content": content if store else None,
"embedding": encode(embedding),
},
replace=True,
)
if show_output:
if format_ == "json" or format_ is None:
click.echo(json.dumps(embedding))
elif format_ == "blob":
click.echo(encode(embedding))
elif format_ == "base64":
click.echo(base64.b64encode(encode(embedding)).decode("ascii"))
elif format_ == "hex":
click.echo(encode(embedding).hex())
@cli.group(
cls=DefaultGroup,
default="list",
default_if_no_args=True,
)
def embed_models():
"Manage available embedding models"
@embed_models.command(name="list")
def embed_models_list():
"List available embedding models"
output = []
for model_with_aliases in get_embedding_models_with_aliases():
s = str(model_with_aliases.model.model_id)
if model_with_aliases.aliases:
s += " (aliases: {})".format(", ".join(model_with_aliases.aliases))
output.append(s)
click.echo("\n".join(output))
@embed_models.command(name="default")
@click.argument("model", required=False)
def embed_models_default(model):
"Show or set the default embedding model"
if not model:
click.echo(get_default_embedding_model())
return
# Validate it is a known model
try:
model = get_embedding_model(model)
set_default_embedding_model(model.model_id)
except KeyError:
raise click.ClickException("Unknown embedding model: {}".format(model))
@cli.group()
def embed_db():
"Manage the embeddings database"
@embed_db.command(name="path")
def embed_db_path():
"Output the path to the embeddings database"
click.echo(user_dir() / "embeddings.db")
@embed_db.command(name="collections")
@click.option(
"-d",
"--database",
type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
envvar="LLM_EMBEDDINGS_DB",
help="Path to embeddings database",
)
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
def embed_db_collections(database, json_):
"Output the path to the embeddings database"
database = database or (user_dir() / "embeddings.db")
db = sqlite_utils.Database(str(database))
if not db["collections"].exists():
raise click.ClickException("No collections table found in {}".format(database))
rows = db.query(
"""
select
collections.name,
collections.model,
count(embeddings.id) as num_embeddings
from
collections left join embeddings
on collections.id = embeddings.collection_id
group by
collections.name, collections.model
"""
)
if json_:
click.echo(json.dumps(list(rows), indent=4))
else:
for row in rows:
click.echo("{}: {}".format(row["name"], row["model"]))
click.echo(
" {} embedding{}".format(
row["num_embeddings"], "s" if row["num_embeddings"] != 1 else ""
)
)
def get_collection(db, collection):
rows = db["collections"].rows_where("name = ?", [collection])
try:
return next(rows)
except StopIteration:
raise NotFoundError("Collection not found: {}".format(collection))
def template_dir():
path = user_dir() / "templates"
path.mkdir(parents=True, exist_ok=True)
@ -865,19 +1091,27 @@ def _truncate_string(s, max_length=100):
return s
def get_default_model():
path = user_dir() / "default_model.txt"
def get_default_model(filename="default_model.txt", default=DEFAULT_MODEL):
path = user_dir() / filename
if path.exists():
return path.read_text().strip()
else:
return DEFAULT_MODEL
return default
def set_default_model(model):
path = user_dir() / "default_model.txt"
def set_default_model(model, filename="default_model.txt"):
path = user_dir() / filename
path.write_text(model)
def get_default_embedding_model():
return get_default_model("default_embedding_model.txt", DEFAULT_EMBEDDING_MODEL)
def set_default_embedding_model(model):
set_default_model(model, "default_embedding_model.txt")
def logs_db_path():
return user_dir() / "logs.db"
@ -947,3 +1181,11 @@ def _human_readable_size(size_bytes):
def logs_on():
return not (user_dir() / "logs-off").exists()
def encode(values):
return struct.pack("<" + "f" * len(values), *values)
def decode(binary):
return struct.unpack("<" + "f" * (len(binary) // 4), binary)

View file

@ -1,4 +1,4 @@
from llm import Model, hookimpl
from llm import EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string
import click
@ -33,9 +33,18 @@ def register_models(register):
aliases = extra_model.get("aliases", [])
model_name = extra_model["model_name"]
api_base = extra_model.get("api_base")
api_type = extra_model.get("api_type")
api_version = extra_model.get("api_version")
api_engine = extra_model.get("api_engine")
headers = extra_model.get("headers")
chat_model = Chat(
model_id, model_name=model_name, api_base=api_base, headers=headers
model_id,
model_name=model_name,
api_base=api_base,
api_type=api_type,
api_version=api_version,
api_engine=api_engine,
headers=headers,
)
if api_base:
chat_model.needs_key = None
@ -47,6 +56,23 @@ def register_models(register):
)
@hookimpl
def register_embedding_models(register):
register(Ada002(), aliases=("ada",))
class Ada002(EmbeddingModel):
model_id = "ada-002"
embedding_size = 1536
needs_key = "openai"
key_env_var = "OPENAI_API_KEY"
def embed(self, text):
return openai.Embedding.create(
input=text, model="text-embedding-ada-002", api_key=self.get_key()
)["data"][0]["embedding"]
@hookimpl
def register_commands(cli):
@cli.group(name="openai")
@ -179,12 +205,23 @@ class Chat(Model):
return validated_logit_bias
def __init__(
self, model_id, key=None, model_name=None, api_base=None, headers=None
self,
model_id,
key=None,
model_name=None,
api_base=None,
api_type=None,
api_version=None,
api_engine=None,
headers=None,
):
self.model_id = model_id
self.key = key
self.model_name = model_name
self.api_base = api_base
self.api_type = api_type
self.api_version = api_version
self.api_engine = api_engine
self.headers = headers
def __str__(self):
@ -214,6 +251,12 @@ class Chat(Model):
kwargs = dict(not_nulls(prompt.options))
if self.api_base:
kwargs["api_base"] = self.api_base
if self.api_type:
kwargs["api_type"] = self.api_type
if self.api_version:
kwargs["api_version"] = self.api_version
if self.api_engine:
kwargs["engine"] = self.api_engine
if self.needs_key:
if self.key:
kwargs["api_key"] = self.key

View file

@ -0,0 +1,19 @@
from sqlite_migrate import Migrations
embeddings_migrations = Migrations("llm.embeddings")
@embeddings_migrations()
def m001_create_tables(db):
db["collections"].create({"id": int, "name": str, "model": str}, pk="id")
db["collections"].create_index(["name"], unique=True)
db["embeddings"].create(
{
"collection_id": int,
"id": str,
"embedding": bytes,
"content": str,
"metadata": str,
},
pk=("collection_id", "id"),
)

View file

@ -13,3 +13,8 @@ def register_commands(cli):
@hookspec
def register_models(register):
"Return a list of model instances representing LLM models that can be called"
@hookspec
def register_embedding_models(register):
"Return a list of model instances that can be used for embedding"

View file

@ -208,16 +208,7 @@ class Options(BaseModel):
_Options = Options
class Model(ABC):
model_id: str
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
class Options(_Options):
pass
class _get_key_mixin:
def get_key(self):
from llm import get_key
@ -244,6 +235,17 @@ class Model(ABC):
message += " or set the {} environment variable".format(self.key_env_var)
raise NeedsKeyException(message)
class Model(ABC, _get_key_mixin):
model_id: str
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
can_stream: bool = False
class Options(_Options):
pass
def conversation(self):
return Conversation(model=self)
@ -283,12 +285,33 @@ class Model(ABC):
return "<Model '{}'>".format(self.model_id)
class EmbeddingModel(ABC, _get_key_mixin):
model_id: str
embedding_size: int
key: Optional[str] = None
needs_key: Optional[str] = None
key_env_var: Optional[str] = None
@abstractmethod
def embed(self, text: str) -> List[float]:
"""
Embed a some text as a list of floats
"""
pass
@dataclass
class ModelWithAliases:
model: Model
aliases: Set[str]
@dataclass
class EmbeddingModelWithAliases:
model: EmbeddingModel
aliases: Set[str]
def _conversation_name(text):
# Collapse whitespace, including newlines
text = re.sub(r"\s+", " ", text)

View file

@ -6,3 +6,5 @@ ignore_missing_imports = True
[mypy-click_default_group.*]
ignore_missing_imports = True
[mypy-sqlite_migrate.*]
ignore_missing_imports = True

View file

@ -40,6 +40,7 @@ setup(
"openai",
"click-default-group-wheel",
"sqlite-utils>=3.35.0",
"sqlite-migrate",
"pydantic>=1.10.2",
"PyYAML",
"pluggy",

View file

@ -1,4 +1,6 @@
import pytest
import llm
from llm.plugins import pm
def pytest_configure(config):
@ -26,6 +28,33 @@ def env_setup(monkeypatch, user_path):
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
def embed(self, text):
words = text.split()[:16]
embedding = [len(word) for word in words]
# Pad with 0 up to 16 words
embedding += [0] * (16 - len(embedding))
return embedding
@pytest.fixture(autouse=True)
def register_embed_demo_model():
class EmbedDemoPlugin:
__name__ = "EmbedDemoPlugin"
@llm.hookimpl
def register_embedding_models(self, register):
register(EmbedDemo())
pm.register(EmbedDemoPlugin(), name="undo-embed-demo-plugin")
try:
yield
finally:
pm.unregister(name="undo-embed-demo-plugin")
@pytest.fixture
def mocked_openai(requests_mock):
return requests_mock.post(

6
tests/test_embed.py Normal file
View file

@ -0,0 +1,6 @@
import llm
def test_demo_plugin():
model = llm.get_embedding_model("embed-demo")
assert model.embed("hello world") == [5, 5] + [0] * 14

105
tests/test_embed_cli.py Normal file
View file

@ -0,0 +1,105 @@
from click.testing import CliRunner
from llm.cli import cli
import json
import pytest
import sqlite_utils
@pytest.mark.parametrize(
"format_,expected",
(
("json", "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"),
(
"base64",
(
"AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
),
),
(
"hex",
(
"0000a0400000a04000000000000000000000000000000000000000000"
"000000000000000000000000000000000000000000000000000000000"
"00000000000000\n"
),
),
(
"blob",
(
b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
).decode("utf-8"),
),
),
)
def test_embed_output_format(format_, expected):
runner = CliRunner()
result = runner.invoke(
cli, ["embed", "--format", format_, "-c", "hello world", "-m", "embed-demo"]
)
assert result.exit_code == 0
assert result.output == expected
@pytest.mark.parametrize(
"args,expected_error",
((["-c", "Content", "stories"], "Must provide both collection and id"),),
)
def test_embed_errors(args, expected_error):
runner = CliRunner()
result = runner.invoke(cli, ["embed"] + args)
assert result.exit_code == 1
assert expected_error in result.output
def test_embed_store(user_path):
embeddings_db = user_path / "embeddings.db"
assert not embeddings_db.exists()
runner = CliRunner()
result = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
assert result.exit_code == 0
# Should not have created the table
assert not embeddings_db.exists()
# Now run it to store
result = runner.invoke(
cli, ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
)
assert result.exit_code == 0
assert embeddings_db.exists()
# Check the contents
db = sqlite_utils.Database(str(embeddings_db))
assert list(db["collections"].rows) == [
{"id": 1, "name": "items", "model": "embed-demo"}
]
assert list(db["embeddings"].rows) == [
{
"collection_id": 1,
"id": "1",
"embedding": (
b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
b"\x00\x00\x00\x00\x00\x00\x00"
),
"content": None,
"metadata": None,
}
]
# Should show up in 'llm embed-db collections'
for is_json in (False, True):
args = ["embed-db", "collections"]
if is_json:
args.extend(["--json"])
result2 = runner.invoke(cli, args)
assert result2.exit_code == 0
if is_json:
assert json.loads(result2.output) == [
{"name": "items", "model": "embed-demo", "num_embeddings": 1}
]
else:
assert result2.output == "items: embed-demo\n 1 embedding\n"