diff --git a/docs/help.md b/docs/help.md index e0d8513..2a3221d 100644 --- a/docs/help.md +++ b/docs/help.md @@ -186,6 +186,7 @@ Usage: llm logs list [OPTIONS] Options: -n, --count INTEGER Number of entries to show - 0 for all -p, --path FILE Path to log database + -m, --model TEXT Filter by model or model alias -t, --truncate Truncate long strings in output --help Show this message and exit. ``` diff --git a/docs/logging.md b/docs/logging.md index 28c1182..d91c0ad 100644 --- a/docs/logging.md +++ b/docs/logging.md @@ -62,6 +62,10 @@ Or `-n 0` to see everything that has ever been logged: ```bash llm logs -n 0 ``` +You can filter to logs just for a specific model (or model alias) using `-m/--model`: +```bash +llm logs -m chatgpt +``` You can truncate the display of the prompts and responses using the `-t/--truncate` option: ```bash llm logs -n 5 -t diff --git a/llm/cli.py b/llm/cli.py index b2ed5c8..b40731a 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -385,14 +385,25 @@ def logs_turn_off(): type=click.Path(readable=True, exists=True, dir_okay=False), help="Path to log database", ) +@click.option("-m", "--model", help="Filter by model or model alias") @click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output") -def logs_list(count, path, truncate): +def logs_list(count, path, model, truncate): "Show recent logged prompts and their responses" path = pathlib.Path(path or logs_db_path()) if not path.exists(): raise click.ClickException("No log database found at {}".format(path)) db = sqlite_utils.Database(path) migrate(db) + + model_id = None + if model: + # Resolve alias, if any + try: + model_id = get_model(model).model_id + except UnknownModelError: + # Maybe they uninstalled a model, use the -m option as-is + model_id = model + rows = list( db.query( """ @@ -412,11 +423,13 @@ def logs_list(count, path, truncate): conversations.model as conversation_model from responses - left join conversations on responses.conversation_id = conversations.id - order by responses.id desc{} + left join conversations on responses.conversation_id = conversations.id{where} + order by responses.id desc{limit} """.format( - " limit {}".format(count) if count else "" - ) + where=" where responses.model = :model" if model_id else "", + limit=" limit {}".format(count) if count else "", + ), + {"model": model_id}, ) ) for row in rows: diff --git a/tests/test_llm.py b/tests/test_llm.py index 34c849b..6fc06b5 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -63,6 +63,28 @@ def test_logs_path(monkeypatch, env, user_path): assert result.output.strip() == expected +@pytest.mark.parametrize("model", ("davinci", "curie")) +def test_logs_filtered(user_path, model): + log_path = str(user_path / "logs.db") + db = sqlite_utils.Database(log_path) + migrate(db) + db["responses"].insert_all( + { + "id": str(ULID()).lower(), + "system": "system", + "prompt": "prompt", + "response": "response", + "model": "davinci" if i % 2 == 0 else "curie", + } + for i in range(100) + ) + runner = CliRunner() + result = runner.invoke(cli, ["logs", "list", "-m", model]) + assert result.exit_code == 0 + records = json.loads(result.output.strip()) + assert all(record["model"] == model for record in records) + + @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) @pytest.mark.parametrize("use_stdin", (True, False)) @pytest.mark.parametrize(