mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-16 20:50:25 +00:00
Initial CLI support and plugin hook for embeddings, refs #185
* Embeddings plugin hook + OpenAI implementation
* llm.get_embedding_model(name) function
* llm embed command, for returning embeddings or saving them to SQLite
* Tests using an EmbedDemo embedding model
* llm embed-models list and embed-models default commands
* llm embed-db path and llm embed-db collections commands
parent
cee5b06604
commit
77cf56e54a
17 changed files with 825 additions and 28 deletions
21
docs/embeddings/binary.md
Normal file
@ -0,0 +1,21 @
(embeddings-binary)=
# Binary embedding formats

The default output format of the `llm embed` command is a JSON array of floating point numbers.

LLM stores embeddings in a more space-efficient format: little-endian binary sequences of 32-bit floating point numbers, each represented using 4 bytes.

The following Python functions can be used to convert between the two formats:

```python
import struct

def encode(values):
    return struct.pack("<" + "f" * len(values), *values)

def decode(binary):
    return struct.unpack("<" + "f" * (len(binary) // 4), binary)
```
When using `llm embed` directly, the default output format is JSON.

Use `--format blob` for the binary output, `--format hex` for that binary output as hexadecimal and `--format base64` for that binary output encoded using base64.
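As a quick check of the format described above, the following sketch round-trips a small vector through the binary encoding and shows how the `hex` and `base64` renderings relate to the same bytes. It reuses the `encode()` and `decode()` helpers from this page; the sample vector is arbitrary and chosen only for illustration.

```python
import base64
import struct


def encode(values):
    # Pack floats as little-endian 32-bit values, 4 bytes each
    return struct.pack("<" + "f" * len(values), *values)


def decode(binary):
    # Unpack 4 bytes at a time back into a tuple of floats
    return struct.unpack("<" + "f" * (len(binary) // 4), binary)


vector = [5.0, 5.0, 0.0]  # arbitrary sample values
blob = encode(vector)

hex_form = blob.hex()
base64_form = base64.b64encode(blob).decode("ascii")

# Both text forms decode back to the same bytes, and then to the same floats
assert bytes.fromhex(hex_form) == blob
assert base64.b64decode(base64_form) == blob
restored = decode(blob)
```

Note that `decode()` returns a tuple rather than a list, and that values are stored with 32-bit precision, so a value such as `0.1` will generally not compare equal to the original after a round trip.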
97
docs/embeddings/cli.md
Normal file
@ -0,0 +1,97 @
(embeddings-cli)=
# Embedding with the CLI

LLM provides command-line utilities for calculating and storing embeddings for pieces of content.

(embeddings-llm-embed)=
## llm embed

The `llm embed` command can be used to calculate embedding vectors for a string of content. These can be returned directly to the terminal, stored in a SQLite database, or both.

### Returning embeddings to the terminal

The simplest way to use this command is to pass content to it using the `-c/--content` option, like this:

```bash
llm embed -c 'This is some content'
```
The command will return a JSON array of floating point numbers directly to the terminal:

```json
[0.123, 0.456, 0.789...]
```
By default it uses the {ref}`default embedding model <embeddings-cli-embed-models-default>`.

Use the `-m/--model` option to specify a different model:

```bash
llm embed -m sentence-transformers/all-MiniLM-L6-v2 \
  -c 'This is some content'
```
See {ref}`embeddings-binary` for options to get back embeddings in formats other than JSON.

### Storing embeddings in SQLite

Embeddings are much more useful if you store them somewhere, so you can calculate similarity scores between different embeddings later on.

LLM includes a concept of a "collection" of embeddings. This is a named object where multiple pieces of content can be stored, each with a unique ID.

The `llm embed` command can store results directly in a named collection like this:

```bash
cat one.txt | llm embed my-files one
```
This will store the embedding for the contents of `one.txt` in the `my-files` collection under the key `one`.

A collection will be created the first time you mention it.

Collections have a fixed embedding model, which is the model that was used for the first embedding stored in that collection.

In the above example this would have been the default embedding model at the time that the command was run.

This example stores the embedding of the string "my happy hound" in a collection called `phrases` under the key `hound` and using the model `ada-002`:

```bash
llm embed -m ada-002 -c 'my happy hound' phrases hound
```
By default, the SQLite database used to store embeddings is `embeddings.db` in the user content directory managed by LLM.

You can see the path to this database by running `llm embed-db path`.

You can store embeddings in a different SQLite database by passing a path to it using the `-d/--database` option to `llm embed`. If this file does not exist yet the command will create it:

```bash
llm embed -d my-embeddings.db -c 'my happy hound' phrases hound
```
This creates a database file called `my-embeddings.db` in the current directory.

(embeddings-cli-embed-models-default)=
## llm embed-models default

This command can be used to get and set the default embedding model.

This will return the name of the current default model:
```bash
llm embed-models default
```
You can set a different default like this:
```bash
llm embed-models default name-of-other-model
```
Any of the supported aliases for a model can be passed to this command.

## llm embed-db collections

To list all of the collections in the embeddings database, run this command:

```bash
llm embed-db collections
```
Add `--json` for JSON output:
```bash
llm embed-db collections --json
```
Add `-d/--database` to specify a different database file:
```bash
llm embed-db collections -d my-embeddings.db
```
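To illustrate how embeddings stored by `llm embed` could be read back, here is a minimal sketch using Python's built-in `sqlite3` module. It assumes the `collections` and `embeddings` tables created by the migration in this commit. So that it runs on its own, it builds a throwaway in-memory database with hand-written `create table` statements that only approximate that schema, instead of opening a real `embeddings.db`. The `load_embedding()` helper is hypothetical and not part of LLM.

```python
import sqlite3
import struct


def decode(binary):
    # Same little-endian float32 layout that LLM uses for stored embeddings
    return struct.unpack("<" + "f" * (len(binary) // 4), binary)


def load_embedding(conn, collection, id):
    # Hypothetical helper: look up one embedding by collection name and ID
    row = conn.execute(
        """
        select embeddings.embedding
        from embeddings
        join collections on collections.id = embeddings.collection_id
        where collections.name = ? and embeddings.id = ?
        """,
        (collection, id),
    ).fetchone()
    return None if row is None else decode(row[0])


# Throwaway stand-in for embeddings.db, approximating the schema in this commit
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    create table collections (id integer primary key, name text unique, model text);
    create table embeddings (
        collection_id integer, id text, embedding blob, content text, metadata text,
        primary key (collection_id, id)
    );
    """
)
conn.execute("insert into collections (id, name, model) values (1, 'phrases', 'ada-002')")
conn.execute(
    "insert into embeddings (collection_id, id, embedding) values (1, 'hound', ?)",
    (struct.pack("<3f", 1.0, 2.0, 3.0),),
)

vector = load_embedding(conn, "phrases", "hound")
missing = load_embedding(conn, "phrases", "no-such-id")
```

Against a real database, the in-memory connection would be replaced with `sqlite3.connect(path)`, where `path` is the value printed by `llm embed-db path` or the file passed to `-d/--database`.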
21
docs/embeddings/index.md
Normal file
@ -0,0 +1,21 @
(embeddings)=
# Embeddings

Embedding models allow you to take a piece of text (a word, sentence, paragraph or even a whole article) and convert it into an array of floating point numbers.

This floating point array is called an "embedding vector". It works as a numerical representation of the semantic meaning of the content in a multi-dimensional space.

By calculating the distance between embedding vectors, we can identify which content is semantically "nearest" to other content.

This can be used to build features like related article lookups. It can also be used to build semantic search, where a user can search for a phrase and get back results that are semantically similar to that phrase even if they do not share any exact keywords.

LLM supports multiple embedding models through {ref}`plugins <plugins>`. Once installed, an embedding model can be used on the command-line or via the Python API to calculate and store embeddings for content, and then to perform similarity searches against those embeddings.
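The idea of finding the "nearest" content can be sketched in a few lines of Python. This page does not say which distance measure is intended, so cosine similarity is used here as one common choice. The three-dimensional vectors are made up for illustration (real embedding vectors have hundreds or thousands of dimensions), and the `cosine_similarity()` and `nearest()` helpers are hypothetical, not part of LLM.

```python
import math


def cosine_similarity(a, b):
    # Dot product divided by the product of the two vector lengths
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def nearest(query, candidates):
    # Return the key of the candidate vector most similar to the query
    return max(candidates, key=lambda key: cosine_similarity(query, candidates[key]))


# Made-up 3-dimensional vectors, purely for illustration
vectors = {
    "dog": [0.9, 0.1, 0.0],
    "puppy": [0.8, 0.2, 0.1],
    "invoice": [0.0, 0.1, 0.9],
}
query = [0.85, 0.15, 0.05]
best = nearest(query, {k: v for k, v in vectors.items() if k != "dog"})
```

A score close to 1 means two vectors point in nearly the same direction, while a score close to 0 means they are unrelated.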

```{toctree}
---
maxdepth: 3
---
cli
writing-plugins
binary
```
48
docs/embeddings/writing-plugins.md
Normal file
@ -0,0 +1,48 @
(embeddings-writing-plugins)=
# Writing plugins to add new embedding models

Read the {ref}`plugin tutorial <tutorial-model-plugin>` for details on how to develop and package a plugin.

This page shows an example plugin that implements and registers a new embedding model.

There are two components to an embedding model plugin:

1. An implementation of the `register_embedding_models()` hook, which takes a `register` callback function and calls it to register the new model with the LLM plugin system.
2. A class that extends the `llm.EmbeddingModel` abstract base class.

    The only required method on this class is `embed(text)`, which takes a string and returns a list of floating point numbers.

The following example uses the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) package to provide access to the [MiniLM-L6](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) embedding model.

```python
import llm
from sentence_transformers import SentenceTransformer


@llm.hookimpl
def register_embedding_models(register):
    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    register(SentenceTransformerModel(model_id, model_id, 384), aliases=("all-MiniLM-L6-v2",))


class SentenceTransformerModel(llm.EmbeddingModel):
    def __init__(self, model_id, model_name, embedding_size):
        self.model_id = model_id
        self.model_name = model_name
        self.embedding_size = embedding_size
        self._model = None

    def embed(self, text):
        if self._model is None:
            self._model = SentenceTransformer(self.model_name)
        return list(map(float, self._model.encode([text])[0]))
```
Once installed, the model provided by this plugin can be used with the {ref}`llm embed <embeddings-llm-embed>` command like this:

```bash
cat file.txt | llm embed -m sentence-transformers/all-MiniLM-L6-v2
```
Or via its registered alias like this:
```bash
cat file.txt | llm embed -m all-MiniLM-L6-v2
```
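The two-part structure described on this page (a hook that calls `register`, plus a class with an `embed()` method) can also be sketched without any heavy dependencies. The example below is loosely modelled on the `EmbedDemo` test model added in this commit. To keep it runnable without LLM installed, `EmbeddingModelBase`, `registry` and the module-level `register()` function are hypothetical stand-ins for `llm.EmbeddingModel` and the plugin system's callback, and `WordLengthModel` is an invented toy model.

```python
class EmbeddingModelBase:
    # Hypothetical stand-in for llm.EmbeddingModel, so the sketch runs without LLM
    model_id = None

    def embed(self, text):
        raise NotImplementedError


class WordLengthModel(EmbeddingModelBase):
    # Toy model: the "embedding" is the length of each of the first 4 words
    model_id = "word-length-demo"
    embedding_size = 4

    def embed(self, text):
        lengths = [float(len(word)) for word in text.split()[: self.embedding_size]]
        # Pad with zeros so every vector has the same number of dimensions
        return lengths + [0.0] * (self.embedding_size - len(lengths))


registry = {}


def register(model, aliases=None):
    # Stand-in for the callback LLM passes to the hook: index by ID and aliases
    registry[model.model_id] = model
    for alias in aliases or ():
        registry[alias] = model


def register_embedding_models(register):
    # In a real plugin this function would be decorated with @llm.hookimpl
    register(WordLengthModel(), aliases=("wl",))


register_embedding_models(register)
vector = registry["wl"].embed("hello big world")
```

In a real plugin the class would extend `llm.EmbeddingModel` and the registry would be managed by LLM itself, so only the model class and the decorated hook function need to be written.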
103
docs/help.md
@ -53,16 +53,19 @@ Options:
   --help  Show this message and exit.
 
 Commands:
-  prompt*    Execute a prompt
-  aliases    Manage model aliases
-  install    Install packages from PyPI into the same environment as LLM
-  keys       Manage stored API keys for different models
-  logs       Tools for exploring logged prompts and responses
-  models     Manage available models
-  openai     Commands for working directly with the OpenAI API
-  plugins    List installed plugins
-  templates  Manage stored prompt templates
-  uninstall  Uninstall Python packages from the LLM environment
+  prompt*       Execute a prompt
+  aliases       Manage model aliases
+  embed         Embed text and store or return the result
+  embed-db      Manage the embeddings database
+  embed-models  Manage available embedding models
+  install       Install packages from PyPI into the same environment as LLM
+  keys          Manage stored API keys for different models
+  logs          Tools for exploring logged prompts and responses
+  models        Manage available models
+  openai        Commands for working directly with the OpenAI API
+  plugins       List installed plugins
+  templates     Manage stored prompt templates
+  uninstall     Uninstall Python packages from the LLM environment
 ```
 ### llm prompt --help
 ```
|
|
@ -380,6 +383,86 @@ Options:
|
|||
-y, --yes Don't ask for confirmation
|
||||
--help Show this message and exit.
|
||||
```
|
||||
### llm embed --help
|
||||
```
|
||||
Usage: llm embed [OPTIONS] [COLLECTION] [ID]
|
||||
|
||||
Embed text and store or return the result
|
||||
|
||||
Options:
|
||||
-i, --input FILE Content to embed
|
||||
-m, --model TEXT Embedding model to use
|
||||
--store Store the text itself in the database
|
||||
-d, --database FILE
|
||||
-c, --content FILE
|
||||
-f, --format [json|blob|base64|hex]
|
||||
Output format
|
||||
--help Show this message and exit.
|
||||
```
|
||||
### llm embed-models --help
|
||||
```
|
||||
Usage: llm embed-models [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Manage available embedding models
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
list* List available embedding models
|
||||
default Show or set the default embedding model
|
||||
```
|
||||
#### llm embed-models list --help
|
||||
```
|
||||
Usage: llm embed-models list [OPTIONS]
|
||||
|
||||
List available embedding models
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
```
|
||||
#### llm embed-models default --help
|
||||
```
|
||||
Usage: llm embed-models default [OPTIONS] [MODEL]
|
||||
|
||||
Show or set the default embedding model
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
```
|
||||
### llm embed-db --help
|
||||
```
|
||||
Usage: llm embed-db [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Manage the embeddings database
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
collections List the collections in the embeddings database
|
||||
path Output the path to the embeddings database
|
||||
```
|
||||
#### llm embed-db path --help
|
||||
```
|
||||
Usage: llm embed-db path [OPTIONS]
|
||||
|
||||
Output the path to the embeddings database
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
```
|
||||
#### llm embed-db collections --help
|
||||
```
|
||||
Usage: llm embed-db collections [OPTIONS]
|
||||
|
||||
List the collections in the embeddings database
|
||||
|
||||
Options:
|
||||
-d, --database FILE Path to embeddings database
|
||||
--json Output as JSON
|
||||
--help Show this message and exit.
|
||||
```
|
||||
### llm openai --help
|
||||
```
|
||||
Usage: llm openai [OPTIONS] COMMAND [ARGS]...
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ maxdepth: 3
|
|||
setup
|
||||
usage
|
||||
other-models
|
||||
embeddings/index
|
||||
plugins/index
|
||||
aliases
|
||||
python-api
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from .models import (
|
|||
Conversation,
|
||||
Model,
|
||||
ModelWithAliases,
|
||||
EmbeddingModel,
|
||||
EmbeddingModelWithAliases,
|
||||
Options,
|
||||
Prompt,
|
||||
Response,
|
||||
@ -73,6 +75,55 @@ def get_models_with_aliases() -> List["ModelWithAliases"]:
    return model_aliases


def get_embedding_models_with_aliases() -> List["EmbeddingModelWithAliases"]:
    model_aliases = []

    # Include aliases from aliases.json
    aliases_path = user_dir() / "aliases.json"
    extra_model_aliases: Dict[str, list] = {}
    if aliases_path.exists():
        configured_aliases = json.loads(aliases_path.read_text())
        for alias, model_id in configured_aliases.items():
            extra_model_aliases.setdefault(model_id, []).append(alias)

    def register(model, aliases=None):
        alias_list = list(aliases or [])
        if model.model_id in extra_model_aliases:
            alias_list.extend(extra_model_aliases[model.model_id])
        model_aliases.append(EmbeddingModelWithAliases(model, alias_list))

    pm.hook.register_embedding_models(register=register)

    return model_aliases


def get_embedding_models():
    models = []

    def register(model, aliases=None):
        models.append(model)

    pm.hook.register_embedding_models(register=register)
    return models


def get_embedding_model(name):
    aliases = get_embedding_model_aliases()
    try:
        return aliases[name]
    except KeyError:
        raise UnknownModelError("Unknown model: " + name)


def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]:
    model_aliases = {}
    for model_with_aliases in get_embedding_models_with_aliases():
        for alias in model_with_aliases.aliases:
            model_aliases[alias] = model_with_aliases.model
        model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model
    return model_aliases


def get_model_aliases() -> Dict[str, Model]:
    model_aliases = {}
    for model_with_aliases in get_models_with_aliases():
|
|
|
|||
252
llm/cli.py
|
|
@ -6,6 +6,8 @@ from llm import (
|
|||
Response,
|
||||
Template,
|
||||
UnknownModelError,
|
||||
get_embedding_models_with_aliases,
|
||||
get_embedding_model,
|
||||
get_key,
|
||||
get_plugins,
|
||||
get_model,
|
||||
|
|
@ -17,12 +19,16 @@ from llm import (
|
|||
)
|
||||
|
||||
from .migrations import migrate
|
||||
from .embeddings_migrations import embeddings_migrations
|
||||
from .plugins import pm
|
||||
import base64
|
||||
import pathlib
|
||||
import pydantic
|
||||
from runpy import run_module
|
||||
import shutil
|
||||
import sqlite_utils
|
||||
from sqlite_utils.db import NotFoundError
|
||||
import struct
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import cast, Optional
|
||||
|
|
@ -32,6 +38,7 @@ import yaml
|
|||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
|
||||
DEFAULT_MODEL = "gpt-3.5-turbo"
|
||||
DEFAULT_EMBEDDING_MODEL = "ada-002"
|
||||
|
||||
DEFAULT_TEMPLATE = "prompt: "
|
||||
|
||||
@ -853,6 +860,225 @@ def uninstall(packages, yes):
    run_module("pip", run_name="__main__")


@cli.command()
@click.argument("collection", required=False)
@click.argument("id", required=False)
@click.option(
    "-i",
    "--input",
    type=click.Path(file_okay=True, allow_dash=True, dir_okay=False),
    help="Content to embed",
)
@click.option("-m", "--model", help="Embedding model to use")
@click.option("--store", is_flag=True, help="Store the text itself in the database")
@click.option(
    "-d",
    "--database",
    type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
    envvar="LLM_EMBEDDINGS_DB",
)
@click.option(
    "-c",
    "--content",
    type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
)
@click.option(
    "format_",
    "-f",
    "--format",
    type=click.Choice(["json", "blob", "base64", "hex"]),
    help="Output format",
)
def embed(collection, id, input, model, store, database, content, format_):
    """Embed text and store or return the result"""
    if collection and not id:
        raise click.ClickException("Must provide both collection and id")

    db = None

    def get_db():
        if database:
            return sqlite_utils.Database(database)
        else:
            return sqlite_utils.Database(user_dir() / "embeddings.db")

    existing_collection = None
    if collection:
        db = get_db()
        if db["collections"].exists():
            try:
                existing_collection = get_collection(db, collection)
            except NotFoundError:
                pass

    if model is None:
        # If collection exists, use that model
        if existing_collection:
            model = existing_collection["model"]
        else:
            # Use default model
            model = get_default_embedding_model()

    if model and existing_collection and model != existing_collection["model"]:
        raise click.ClickException(
            "Model '{}' does not match '{}' collection model of '{}'".format(
                model, collection, existing_collection["model"]
            )
        )

    try:
        model = get_embedding_model(model)
    except UnknownModelError as ex:
        raise click.ClickException(str(ex))

    show_output = True
    if collection and (format_ is None):
        show_output = False

    # Resolve input text
    if not content:
        if not input or input == "-":
            # Read from stdin
            content = sys.stdin.read()
        else:
            # --input is a path string (click.Path), so open the file to read it
            with open(input) as fp:
                content = fp.read()
    if not content:
        raise click.ClickException("No content provided")

    embedding = model.embed(content)

    if collection:
        # Store the embedding
        if db is None:
            db = get_db()

        embeddings_migrations.apply(db)

        if not existing_collection:
            db["collections"].insert(
                {
                    "name": collection,
                    "model": model.model_id,
                }
            )
            existing_collection = get_collection(db, collection)

        # Now store it
        db["embeddings"].insert(
            {
                "collection_id": existing_collection["id"],
                "id": id,
                "content": content if store else None,
                "embedding": encode(embedding),
            },
            replace=True,
        )

    if show_output:
        if format_ == "json" or format_ is None:
            click.echo(json.dumps(embedding))
        elif format_ == "blob":
            click.echo(encode(embedding))
        elif format_ == "base64":
            click.echo(base64.b64encode(encode(embedding)).decode("ascii"))
        elif format_ == "hex":
            click.echo(encode(embedding).hex())


@cli.group(
    cls=DefaultGroup,
    default="list",
    default_if_no_args=True,
)
def embed_models():
    "Manage available embedding models"


@embed_models.command(name="list")
def embed_models_list():
    "List available embedding models"
    output = []
    for model_with_aliases in get_embedding_models_with_aliases():
        s = str(model_with_aliases.model.model_id)
        if model_with_aliases.aliases:
            s += " (aliases: {})".format(", ".join(model_with_aliases.aliases))
        output.append(s)
    click.echo("\n".join(output))


@embed_models.command(name="default")
@click.argument("model", required=False)
def embed_models_default(model):
    "Show or set the default embedding model"
    if not model:
        click.echo(get_default_embedding_model())
        return
    # Validate it is a known model
    try:
        model = get_embedding_model(model)
        set_default_embedding_model(model.model_id)
    except KeyError:
        raise click.ClickException("Unknown embedding model: {}".format(model))


@cli.group()
def embed_db():
    "Manage the embeddings database"


@embed_db.command(name="path")
def embed_db_path():
    "Output the path to the embeddings database"
    click.echo(user_dir() / "embeddings.db")


@embed_db.command(name="collections")
@click.option(
    "-d",
    "--database",
    type=click.Path(file_okay=True, allow_dash=False, dir_okay=False, writable=True),
    envvar="LLM_EMBEDDINGS_DB",
    help="Path to embeddings database",
)
@click.option("json_", "--json", is_flag=True, help="Output as JSON")
def embed_db_collections(database, json_):
    "List the collections in the embeddings database"
    database = database or (user_dir() / "embeddings.db")
    db = sqlite_utils.Database(str(database))
    if not db["collections"].exists():
        raise click.ClickException("No collections table found in {}".format(database))
    rows = db.query(
        """
    select
        collections.name,
        collections.model,
        count(embeddings.id) as num_embeddings
    from
        collections left join embeddings
        on collections.id = embeddings.collection_id
    group by
        collections.name, collections.model
    """
    )
    if json_:
        click.echo(json.dumps(list(rows), indent=4))
    else:
        for row in rows:
            click.echo("{}: {}".format(row["name"], row["model"]))
            click.echo(
                "  {} embedding{}".format(
                    row["num_embeddings"], "s" if row["num_embeddings"] != 1 else ""
                )
            )


def get_collection(db, collection):
    rows = db["collections"].rows_where("name = ?", [collection])
    try:
        return next(rows)
    except StopIteration:
        raise NotFoundError("Collection not found: {}".format(collection))


def template_dir():
    path = user_dir() / "templates"
    path.mkdir(parents=True, exist_ok=True)
@ -865,19 +1091,27 @@ def _truncate_string(s, max_length=100):
     return s
 
 
-def get_default_model():
-    path = user_dir() / "default_model.txt"
+def get_default_model(filename="default_model.txt", default=DEFAULT_MODEL):
+    path = user_dir() / filename
     if path.exists():
         return path.read_text().strip()
     else:
-        return DEFAULT_MODEL
+        return default
 
 
-def set_default_model(model):
-    path = user_dir() / "default_model.txt"
+def set_default_model(model, filename="default_model.txt"):
+    path = user_dir() / filename
     path.write_text(model)
 
 
+def get_default_embedding_model():
+    return get_default_model("default_embedding_model.txt", DEFAULT_EMBEDDING_MODEL)
+
+
+def set_default_embedding_model(model):
+    set_default_model(model, "default_embedding_model.txt")
+
+
 def logs_db_path():
     return user_dir() / "logs.db"
 
@ -947,3 +1181,11 @@ def _human_readable_size(size_bytes):

def logs_on():
    return not (user_dir() / "logs-off").exists()


def encode(values):
    return struct.pack("<" + "f" * len(values), *values)


def decode(binary):
    return struct.unpack("<" + "f" * (len(binary) // 4), binary)
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from llm import Model, hookimpl
|
||||
from llm import EmbeddingModel, Model, hookimpl
|
||||
import llm
|
||||
from llm.utils import dicts_to_table_string
|
||||
import click
|
||||
|
|
@ -33,9 +33,18 @@ def register_models(register):
|
|||
aliases = extra_model.get("aliases", [])
|
||||
model_name = extra_model["model_name"]
|
||||
api_base = extra_model.get("api_base")
|
||||
api_type = extra_model.get("api_type")
|
||||
api_version = extra_model.get("api_version")
|
||||
api_engine = extra_model.get("api_engine")
|
||||
headers = extra_model.get("headers")
|
||||
chat_model = Chat(
|
||||
model_id, model_name=model_name, api_base=api_base, headers=headers
|
||||
model_id,
|
||||
model_name=model_name,
|
||||
api_base=api_base,
|
||||
api_type=api_type,
|
||||
api_version=api_version,
|
||||
api_engine=api_engine,
|
||||
headers=headers,
|
||||
)
|
||||
if api_base:
|
||||
chat_model.needs_key = None
|
||||
|
|
@ -47,6 +56,23 @@ def register_models(register):
|
|||
)
|
||||
|
||||
|
||||
@hookimpl
def register_embedding_models(register):
    register(Ada002(), aliases=("ada",))


class Ada002(EmbeddingModel):
    model_id = "ada-002"
    embedding_size = 1536
    needs_key = "openai"
    key_env_var = "OPENAI_API_KEY"

    def embed(self, text):
        return openai.Embedding.create(
            input=text, model="text-embedding-ada-002", api_key=self.get_key()
        )["data"][0]["embedding"]


@hookimpl
|
||||
def register_commands(cli):
|
||||
@cli.group(name="openai")
|
||||
|
|
@ -179,12 +205,23 @@ class Chat(Model):
|
|||
return validated_logit_bias
|
||||
|
||||
def __init__(
|
||||
self, model_id, key=None, model_name=None, api_base=None, headers=None
|
||||
self,
|
||||
model_id,
|
||||
key=None,
|
||||
model_name=None,
|
||||
api_base=None,
|
||||
api_type=None,
|
||||
api_version=None,
|
||||
api_engine=None,
|
||||
headers=None,
|
||||
):
|
||||
self.model_id = model_id
|
||||
self.key = key
|
||||
self.model_name = model_name
|
||||
self.api_base = api_base
|
||||
self.api_type = api_type
|
||||
self.api_version = api_version
|
||||
self.api_engine = api_engine
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self):
|
||||
|
|
@ -214,6 +251,12 @@ class Chat(Model):
|
|||
kwargs = dict(not_nulls(prompt.options))
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.api_type:
|
||||
kwargs["api_type"] = self.api_type
|
||||
if self.api_version:
|
||||
kwargs["api_version"] = self.api_version
|
||||
if self.api_engine:
|
||||
kwargs["engine"] = self.api_engine
|
||||
if self.needs_key:
|
||||
if self.key:
|
||||
kwargs["api_key"] = self.key
|
||||
|
|
|
|||
19
llm/embeddings_migrations.py
Normal file
@ -0,0 +1,19 @
from sqlite_migrate import Migrations

embeddings_migrations = Migrations("llm.embeddings")


@embeddings_migrations()
def m001_create_tables(db):
    db["collections"].create({"id": int, "name": str, "model": str}, pk="id")
    db["collections"].create_index(["name"], unique=True)
    db["embeddings"].create(
        {
            "collection_id": int,
            "id": str,
            "embedding": bytes,
            "content": str,
            "metadata": str,
        },
        pk=("collection_id", "id"),
    )
@ -13,3 +13,8 @@ def register_commands(cli):
@hookspec
def register_models(register):
    "Return a list of model instances representing LLM models that can be called"


@hookspec
def register_embedding_models(register):
    "Return a list of model instances that can be used for embedding"
|
|
|
|||
|
|
@ -208,16 +208,7 @@ class Options(BaseModel):
|
|||
_Options = Options
|
||||
|
||||
|
||||
class Model(ABC):
|
||||
model_id: str
|
||||
key: Optional[str] = None
|
||||
needs_key: Optional[str] = None
|
||||
key_env_var: Optional[str] = None
|
||||
can_stream: bool = False
|
||||
|
||||
class Options(_Options):
|
||||
pass
|
||||
|
||||
class _get_key_mixin:
|
||||
def get_key(self):
|
||||
from llm import get_key
|
||||
|
||||
|
|
@ -244,6 +235,17 @@ class Model(ABC):
|
|||
message += " or set the {} environment variable".format(self.key_env_var)
|
||||
raise NeedsKeyException(message)
|
||||
|
||||
|
||||
class Model(ABC, _get_key_mixin):
|
||||
model_id: str
|
||||
key: Optional[str] = None
|
||||
needs_key: Optional[str] = None
|
||||
key_env_var: Optional[str] = None
|
||||
can_stream: bool = False
|
||||
|
||||
class Options(_Options):
|
||||
pass
|
||||
|
||||
def conversation(self):
|
||||
return Conversation(model=self)
|
||||
|
||||
@ -283,12 +285,33 @@ class Model(ABC):
        return "<Model '{}'>".format(self.model_id)


class EmbeddingModel(ABC, _get_key_mixin):
    model_id: str
    embedding_size: int
    key: Optional[str] = None
    needs_key: Optional[str] = None
    key_env_var: Optional[str] = None

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """
        Embed some text as a list of floats
        """
        pass


@dataclass
class ModelWithAliases:
    model: Model
    aliases: Set[str]


@dataclass
class EmbeddingModelWithAliases:
    model: EmbeddingModel
    aliases: Set[str]


def _conversation_name(text):
    # Collapse whitespace, including newlines
    text = re.sub(r"\s+", " ", text)
|
|
|
|||
2
mypy.ini
|
|
@ -6,3 +6,5 @@ ignore_missing_imports = True
|
|||
[mypy-click_default_group.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-sqlite_migrate.*]
|
||||
ignore_missing_imports = True
|
||||
|
|
|
|||
1
setup.py
|
|
@ -40,6 +40,7 @@ setup(
|
|||
"openai",
|
||||
"click-default-group-wheel",
|
||||
"sqlite-utils>=3.35.0",
|
||||
"sqlite-migrate",
|
||||
"pydantic>=1.10.2",
|
||||
"PyYAML",
|
||||
"pluggy",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import pytest
|
||||
import llm
|
||||
from llm.plugins import pm
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
@ -26,6 +28,33 @@ def env_setup(monkeypatch, user_path):
    monkeypatch.setenv("LLM_USER_PATH", str(user_path))


class EmbedDemo(llm.EmbeddingModel):
    model_id = "embed-demo"

    def embed(self, text):
        words = text.split()[:16]
        embedding = [len(word) for word in words]
        # Pad with 0 up to 16 words
        embedding += [0] * (16 - len(embedding))
        return embedding


@pytest.fixture(autouse=True)
def register_embed_demo_model():
    class EmbedDemoPlugin:
        __name__ = "EmbedDemoPlugin"

        @llm.hookimpl
        def register_embedding_models(self, register):
            register(EmbedDemo())

    pm.register(EmbedDemoPlugin(), name="undo-embed-demo-plugin")
    try:
        yield
    finally:
        pm.unregister(name="undo-embed-demo-plugin")


@pytest.fixture
def mocked_openai(requests_mock):
    return requests_mock.post(
|
|
|
|||
6
tests/test_embed.py
Normal file
@@ -0,0 +1,6 @@
import llm


def test_demo_plugin():
    model = llm.get_embedding_model("embed-demo")
    assert model.embed("hello world") == [5, 5] + [0] * 14
105
tests/test_embed_cli.py
Normal file
@@ -0,0 +1,105 @@
from click.testing import CliRunner
from llm.cli import cli
import json
import pytest
import sqlite_utils


@pytest.mark.parametrize(
    "format_,expected",
    (
        ("json", "[5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"),
        (
            "base64",
            (
                "AACgQAAAoEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
                "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\n"
            ),
        ),
        (
            "hex",
            (
                "0000a0400000a04000000000000000000000000000000000000000000"
                "000000000000000000000000000000000000000000000000000000000"
                "00000000000000\n"
            ),
        ),
        (
            "blob",
            (
                b"\x00\x00\xef\xbf\xbd@\x00\x00\xef\xbf\xbd@\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n"
            ).decode("utf-8"),
        ),
    ),
)
def test_embed_output_format(format_, expected):
    runner = CliRunner()
    result = runner.invoke(
        cli, ["embed", "--format", format_, "-c", "hello world", "-m", "embed-demo"]
    )
    assert result.exit_code == 0
    assert result.output == expected


@pytest.mark.parametrize(
    "args,expected_error",
    ((["-c", "Content", "stories"], "Must provide both collection and id"),),
)
def test_embed_errors(args, expected_error):
    runner = CliRunner()
    result = runner.invoke(cli, ["embed"] + args)
    assert result.exit_code == 1
    assert expected_error in result.output


def test_embed_store(user_path):
    embeddings_db = user_path / "embeddings.db"
    assert not embeddings_db.exists()
    runner = CliRunner()
    result = runner.invoke(cli, ["embed", "-c", "hello", "-m", "embed-demo"])
    assert result.exit_code == 0
    # Should not have created the table
    assert not embeddings_db.exists()
    # Now run it to store
    result = runner.invoke(
        cli, ["embed", "-c", "hello", "-m", "embed-demo", "items", "1"]
    )
    assert result.exit_code == 0
    assert embeddings_db.exists()
    # Check the contents
    db = sqlite_utils.Database(str(embeddings_db))
    assert list(db["collections"].rows) == [
        {"id": 1, "name": "items", "model": "embed-demo"}
    ]
    assert list(db["embeddings"].rows) == [
        {
            "collection_id": 1,
            "id": "1",
            "embedding": (
                b"\x00\x00\xa0@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
                b"\x00\x00\x00\x00\x00\x00\x00"
            ),
            "content": None,
            "metadata": None,
        }
    ]
    # Should show up in 'llm embed-db collections'
    for is_json in (False, True):
        args = ["embed-db", "collections"]
        if is_json:
            args.extend(["--json"])
        result2 = runner.invoke(cli, args)
        assert result2.exit_code == 0
        if is_json:
            assert json.loads(result2.output) == [
                {"name": "items", "model": "embed-demo", "num_embeddings": 1}
            ]
        else:
            assert result2.output == "items: embed-demo\n 1 embedding\n"
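
Note on the expected values in `tests/test_embed_cli.py`: the `hex` and `base64` strings appear to be the little-endian 32-bit float encoding described in `docs/embeddings/binary.md`, applied to the `embed-demo` vector for "hello world". The sketch below is a standalone reconstruction and is not part of this commit. `encode` and `decode` are copied from the docs, while `embed_demo` is a hypothetical helper that mirrors the `EmbedDemo.embed` test model without importing `llm`.

```python
import base64
import struct


def encode(values):
    # Little-endian 32-bit floats, 4 bytes per value
    return struct.pack("<" + "f" * len(values), *values)


def decode(binary):
    # Inverse of encode(): returns a tuple of floats
    return struct.unpack("<" + "f" * (len(binary) // 4), binary)


def embed_demo(text):
    # Hypothetical stand-in for EmbedDemo.embed: one value per word
    # (its length), truncated or zero-padded to exactly 16 values
    words = text.split()[:16]
    embedding = [len(word) for word in words]
    return embedding + [0] * (16 - len(embedding))


blob = encode(embed_demo("hello world"))
print(len(blob))  # 16 floats * 4 bytes each
print(blob.hex())
print(base64.b64encode(blob).decode("ascii"))
print(decode(blob)[:3])
```

Each 5.0 is packed as the bytes `00 00 a0 40`, which is why the hex output starts with `0000a0400000a040`. The `blob` case in the test differs because the raw bytes pass through UTF-8 decoding in the test runner, so each `\xa0` byte shows up as the replacement sequence `\xef\xbf\xbd`.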