Don't require input if template does not use $input, closes #835

This commit is contained in:
Simon Willison 2025-03-15 19:17:24 -07:00
parent bc692e1f19
commit bfbcc201b7
3 changed files with 24 additions and 4 deletions

View file

@ -397,9 +397,11 @@ def prompt(
extract_last = template_obj.extract_last
if template_obj.schema_object:
schema = template_obj.schema_object
prompt = read_prompt()
input_ = ""
if "input" in template_obj.vars():
input_ = read_prompt()
try:
prompt, system = template_obj.evaluate(prompt, params)
prompt, system = template_obj.evaluate(input_, params)
except Template.MissingVariables as ex:
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:

View file

@ -38,6 +38,14 @@ class Template(BaseModel):
system = self.interpolate(self.system, params)
return prompt, system
def vars(self) -> set:
all_vars = set()
for text in [self.prompt, self.system]:
if not text:
continue
all_vars.update(self.extract_vars(string.Template(text)))
return all_vars
@classmethod
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
if not text:

View file

@ -146,10 +146,11 @@ def test_templates_error_on_missing_schema(templates_path):
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize(
"template,extra_args,expected_model,expected_input,expected_error",
"template,input_text,extra_args,expected_model,expected_input,expected_error",
(
(
"'Summarize this: $input'",
"Input text",
[],
"gpt-4o-mini",
"Summarize this: Input text",
@ -157,6 +158,7 @@ def test_templates_error_on_missing_schema(templates_path):
),
(
"prompt: 'Summarize this: $input'\nmodel: gpt-4",
"Input text",
[],
"gpt-4",
"Summarize this: Input text",
@ -164,6 +166,7 @@ def test_templates_error_on_missing_schema(templates_path):
),
(
"prompt: 'Summarize this: $input'",
"Input text",
["-m", "4"],
"gpt-4",
"Summarize this: Input text",
@ -171,6 +174,7 @@ def test_templates_error_on_missing_schema(templates_path):
),
pytest.param(
"boo",
"Input text",
["-s", "s"],
None,
None,
@ -179,6 +183,7 @@ def test_templates_error_on_missing_schema(templates_path):
),
pytest.param(
"prompt: 'Say $hello'",
"Input text",
[],
None,
None,
@ -187,6 +192,7 @@ def test_templates_error_on_missing_schema(templates_path):
),
(
"prompt: 'Say $hello'",
"Input text",
["-p", "hello", "Blah"],
"gpt-4o-mini",
"Say Blah",
@ -194,6 +200,7 @@ def test_templates_error_on_missing_schema(templates_path):
),
(
"prompt: 'Say pelican'",
"",
[],
"gpt-4o-mini",
"Say pelican",
@ -205,6 +212,7 @@ def test_template_basic(
templates_path,
mocked_openai_chat,
template,
input_text,
extra_args,
expected_model,
expected_input,
@ -214,7 +222,9 @@ def test_template_basic(
runner = CliRunner()
result = runner.invoke(
cli,
["--no-stream", "-t", "template", "Input text"] + extra_args,
["--no-stream", "-t", "template"]
+ ([input_text] if input_text else [])
+ extra_args,
catch_exceptions=False,
)
if expected_error is None: