diff --git a/docs/templates.md b/docs/templates.md index 76d5984..2d6cd43 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -28,7 +28,10 @@ You can also save default parameters: llm --system 'Summarize this text in the voice of $voice' \ --model gpt-4 -p voice GlaDOS --save summarize ``` - +Or options: +```bash +llm --system 'Speak in French' -o temperature 1.8 --save wild-french +``` Add `--schema` to bake a {ref}`schema ` into your template: ```bash @@ -143,6 +146,20 @@ system: You speak like an excitable Victorian adventurer prompt: 'Summarize this: $input' ``` + +(prompt-templates-options)= + +### Options + +Default options can be set using the `options:` key: + +```yaml +name: wild-french +system: Speak in French +options: + temperature: 1.8 +``` + (prompt-templates-schemas)= ### Schemas diff --git a/llm/cli.py b/llm/cli.py index 2727a02..3ddbcdf 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -397,6 +397,17 @@ def prompt( to_save["extract_last"] = True if schema: to_save["schema_object"] = schema + if options: + # Need to validate and convert their types first + model = get_model(model_id or get_default_model()) + try: + to_save["options"] = dict( + (key, value) + for key, value in model.Options(**dict(options)) + if value is not None + ) + except pydantic.ValidationError as ex: + raise click.ClickException(render_errors(ex.errors())) path.write_text( yaml.dump( to_save, @@ -419,10 +430,21 @@ def prompt( if template_obj.schema_object: schema = template_obj.schema_object input_ = "" + if template_obj.options: + # Make options mutable (they start as a tuple) + options = list(options) + # Load any options, provided they were not set using -o already + specified_options = dict(options) + for option_name, option_value in template_obj.options.items(): + if option_name not in specified_options: + options.append((option_name, option_value)) if "input" in template_obj.vars(): input_ = read_prompt() try: - prompt, system = template_obj.evaluate(input_, params) + template_prompt, system = template_obj.evaluate(input_, params) + if template_prompt: + # Over-ride user prompt only if the template provided one + prompt = template_prompt except Template.MissingVariables as ex: raise click.ClickException(str(ex)) if model_id is None and template_obj.model: diff --git a/llm/templates.py b/llm/templates.py index 94a62c3..9d3495e 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -9,6 +9,7 @@ class Template(BaseModel): system: Optional[str] = None model: Optional[str] = None defaults: Optional[Dict[str, Any]] = None + options: Optional[Dict[str, Any]] = None # Should a fenced code block be extracted? extract: Optional[bool] = None extract_last: Optional[bool] = None diff --git a/tests/test_templates.py b/tests/test_templates.py index 957bd77..33163db 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -91,6 +91,12 @@ def test_templates_list(templates_path, args): {"prompt": "Say hello as $name", "defaults": {"name": "default-name"}}, None, ), + # Options + ( + ["-o", "temperature", "0.5", "--system", "in french"], + {"system": "in french", "options": {"temperature": 0.5}}, + None, + ), # -x/--extract should be persisted: ( ["--system", "write python", "--extract"], @@ -146,7 +152,7 @@ def test_templates_error_on_missing_schema(templates_path): @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) @pytest.mark.parametrize( - "template,input_text,extra_args,expected_model,expected_input,expected_error", + "template,input_text,extra_args,expected_model,expected_input,expected_error,expected_options", ( ( "'Summarize this: $input'", @@ -155,6 +161,7 @@ def test_templates_error_on_missing_schema(templates_path): "gpt-4o-mini", "Summarize this: Input text", None, + None, ), ( "prompt: 'Summarize this: $input'\nmodel: gpt-4", @@ -163,6 +170,7 @@ def test_templates_error_on_missing_schema(templates_path): "gpt-4", "Summarize this: Input text", None, + None, ), ( "prompt: 'Summarize this: $input'", @@ -171,6 +179,7 @@ def test_templates_error_on_missing_schema(templates_path): "gpt-4", "Summarize this: Input text", None, + None, ), pytest.param( "boo", @@ -179,6 +188,7 @@ def test_templates_error_on_missing_schema(templates_path): None, None, "Error: Cannot use -t/--template and --system together", + None, marks=pytest.mark.httpx_mock(), ), pytest.param( @@ -188,6 +198,7 @@ def test_templates_error_on_missing_schema(templates_path): None, None, "Error: Missing variables: hello", + None, marks=pytest.mark.httpx_mock(), ), ( @@ -197,6 +208,7 @@ def test_templates_error_on_missing_schema(templates_path): "gpt-4o-mini", "Say Blah", None, + None, ), ( "prompt: 'Say pelican'", @@ -205,10 +217,44 @@ def test_templates_error_on_missing_schema(templates_path): "gpt-4o-mini", "Say pelican", None, + None, + ), + # Template with just a system prompt + ( + "system: 'Summarize this'", + "Input text", + [], + "gpt-4o-mini", + [ + {"content": "Summarize this", "role": "system"}, + {"content": "Input text", "role": "user"}, + ], + None, + None, + ), + # Options + ( + "prompt: 'Summarize this: $input'\noptions:\n temperature: 0.5", + "Input text", + [], + "gpt-4o-mini", + "Summarize this: Input text", + None, + {"temperature": 0.5}, + ), + # Should be over-ridden by CLI + ( + "prompt: 'Summarize this: $input'\noptions:\n temperature: 0.5", + "Input text", + ["-o", "temperature", "0.7"], + "gpt-4o-mini", + "Summarize this: Input text", + None, + {"temperature": 0.7}, ), ), ) -def test_template_basic( +def test_execute_prompt_with_a_template( templates_path, mocked_openai_chat, template, @@ -217,6 +263,7 @@ def test_template_basic( expected_model, expected_input, expected_error, + expected_options, ): (templates_path / "template.yaml").write_text(template, "utf-8") runner = CliRunner() @@ -227,14 +274,22 @@ def test_template_basic( + extra_args, catch_exceptions=False, ) + if isinstance(expected_input, str): + expected_messages = [{"role": "user", "content": expected_input}] + else: + expected_messages = expected_input + if expected_error is None: assert result.exit_code == 0 last_request = mocked_openai_chat.get_requests()[-1] - assert json.loads(last_request.content) == { + expected_data = { "model": expected_model, - "messages": [{"role": "user", "content": expected_input}], + "messages": expected_messages, "stream": False, } + if expected_options: + expected_data.update(expected_options) + assert json.loads(last_request.content) == expected_data else: assert result.exit_code == 1 assert result.output.strip() == expected_error