From bfbcc201b72a0a406ff8dffadd6d8dd4bf717015 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 15 Mar 2025 19:17:24 -0700 Subject: [PATCH] Don't require input if template does not use $input, closes #835 --- llm/cli.py | 6 ++++-- llm/templates.py | 8 ++++++++ tests/test_templates.py | 14 ++++++++++++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/llm/cli.py b/llm/cli.py index 20bbdee..fb04dcd 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -397,9 +397,11 @@ def prompt( extract_last = template_obj.extract_last if template_obj.schema_object: schema = template_obj.schema_object - prompt = read_prompt() + input_ = "" + if "input" in template_obj.vars(): + input_ = read_prompt() try: - prompt, system = template_obj.evaluate(prompt, params) + prompt, system = template_obj.evaluate(input_, params) except Template.MissingVariables as ex: raise click.ClickException(str(ex)) if model_id is None and template_obj.model: diff --git a/llm/templates.py b/llm/templates.py index 502007f..94a62c3 100644 --- a/llm/templates.py +++ b/llm/templates.py @@ -38,6 +38,14 @@ class Template(BaseModel): system = self.interpolate(self.system, params) return prompt, system + def vars(self) -> set: + all_vars = set() + for text in [self.prompt, self.system]: + if not text: + continue + all_vars.update(self.extract_vars(string.Template(text))) + return all_vars + @classmethod def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]: if not text: diff --git a/tests/test_templates.py b/tests/test_templates.py index ba3fe4a..957bd77 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -146,10 +146,11 @@ def test_templates_error_on_missing_schema(templates_path): @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) @pytest.mark.parametrize( - "template,extra_args,expected_model,expected_input,expected_error", + "template,input_text,extra_args,expected_model,expected_input,expected_error", ( ( "'Summarize this: $input'", + "Input text", [], "gpt-4o-mini", "Summarize this: Input text", @@ -157,6 +158,7 @@ def test_templates_error_on_missing_schema(templates_path): ), ( "prompt: 'Summarize this: $input'\nmodel: gpt-4", + "Input text", [], "gpt-4", "Summarize this: Input text", @@ -164,6 +166,7 @@ def test_templates_error_on_missing_schema(templates_path): ), ( "prompt: 'Summarize this: $input'", + "Input text", ["-m", "4"], "gpt-4", "Summarize this: Input text", @@ -171,6 +174,7 @@ def test_templates_error_on_missing_schema(templates_path): ), pytest.param( "boo", + "Input text", ["-s", "s"], None, None, @@ -179,6 +183,7 @@ def test_templates_error_on_missing_schema(templates_path): ), pytest.param( "prompt: 'Say $hello'", + "Input text", [], None, None, @@ -187,6 +192,7 @@ def test_templates_error_on_missing_schema(templates_path): ), ( "prompt: 'Say $hello'", + "Input text", ["-p", "hello", "Blah"], "gpt-4o-mini", "Say Blah", @@ -194,6 +200,7 @@ def test_templates_error_on_missing_schema(templates_path): ), ( "prompt: 'Say pelican'", + "", [], "gpt-4o-mini", "Say pelican", @@ -205,6 +212,7 @@ def test_template_basic( templates_path, mocked_openai_chat, template, + input_text, extra_args, expected_model, expected_input, @@ -214,7 +222,9 @@ def test_template_basic( runner = CliRunner() result = runner.invoke( cli, - ["--no-stream", "-t", "template", "Input text"] + extra_args, + ["--no-stream", "-t", "template"] + + ([input_text] if input_text else []) + + extra_args, catch_exceptions=False, ) if expected_error is None: