Don't require input if template does not use $input, closes #835

This commit is contained in:
Simon Willison 2025-03-15 19:17:24 -07:00
parent bc692e1f19
commit bfbcc201b7
3 changed files with 24 additions and 4 deletions

View file

@ -397,9 +397,11 @@ def prompt(
extract_last = template_obj.extract_last extract_last = template_obj.extract_last
if template_obj.schema_object: if template_obj.schema_object:
schema = template_obj.schema_object schema = template_obj.schema_object
prompt = read_prompt() input_ = ""
if "input" in template_obj.vars():
input_ = read_prompt()
try: try:
prompt, system = template_obj.evaluate(prompt, params) prompt, system = template_obj.evaluate(input_, params)
except Template.MissingVariables as ex: except Template.MissingVariables as ex:
raise click.ClickException(str(ex)) raise click.ClickException(str(ex))
if model_id is None and template_obj.model: if model_id is None and template_obj.model:

View file

@ -38,6 +38,14 @@ class Template(BaseModel):
system = self.interpolate(self.system, params) system = self.interpolate(self.system, params)
return prompt, system return prompt, system
def vars(self) -> set:
all_vars = set()
for text in [self.prompt, self.system]:
if not text:
continue
all_vars.update(self.extract_vars(string.Template(text)))
return all_vars
@classmethod @classmethod
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]: def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
if not text: if not text:

View file

@ -146,10 +146,11 @@ def test_templates_error_on_missing_schema(templates_path):
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"}) @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
@pytest.mark.parametrize( @pytest.mark.parametrize(
"template,extra_args,expected_model,expected_input,expected_error", "template,input_text,extra_args,expected_model,expected_input,expected_error",
( (
( (
"'Summarize this: $input'", "'Summarize this: $input'",
"Input text",
[], [],
"gpt-4o-mini", "gpt-4o-mini",
"Summarize this: Input text", "Summarize this: Input text",
@ -157,6 +158,7 @@ def test_templates_error_on_missing_schema(templates_path):
), ),
( (
"prompt: 'Summarize this: $input'\nmodel: gpt-4", "prompt: 'Summarize this: $input'\nmodel: gpt-4",
"Input text",
[], [],
"gpt-4", "gpt-4",
"Summarize this: Input text", "Summarize this: Input text",
@ -164,6 +166,7 @@ def test_templates_error_on_missing_schema(templates_path):
), ),
( (
"prompt: 'Summarize this: $input'", "prompt: 'Summarize this: $input'",
"Input text",
["-m", "4"], ["-m", "4"],
"gpt-4", "gpt-4",
"Summarize this: Input text", "Summarize this: Input text",
@ -171,6 +174,7 @@ def test_templates_error_on_missing_schema(templates_path):
), ),
pytest.param( pytest.param(
"boo", "boo",
"Input text",
["-s", "s"], ["-s", "s"],
None, None,
None, None,
@ -179,6 +183,7 @@ def test_templates_error_on_missing_schema(templates_path):
), ),
pytest.param( pytest.param(
"prompt: 'Say $hello'", "prompt: 'Say $hello'",
"Input text",
[], [],
None, None,
None, None,
@ -187,6 +192,7 @@ def test_templates_error_on_missing_schema(templates_path):
), ),
( (
"prompt: 'Say $hello'", "prompt: 'Say $hello'",
"Input text",
["-p", "hello", "Blah"], ["-p", "hello", "Blah"],
"gpt-4o-mini", "gpt-4o-mini",
"Say Blah", "Say Blah",
@ -194,6 +200,7 @@ def test_templates_error_on_missing_schema(templates_path):
), ),
( (
"prompt: 'Say pelican'", "prompt: 'Say pelican'",
"",
[], [],
"gpt-4o-mini", "gpt-4o-mini",
"Say pelican", "Say pelican",
@ -205,6 +212,7 @@ def test_template_basic(
templates_path, templates_path,
mocked_openai_chat, mocked_openai_chat,
template, template,
input_text,
extra_args, extra_args,
expected_model, expected_model,
expected_input, expected_input,
@ -214,7 +222,9 @@ def test_template_basic(
runner = CliRunner() runner = CliRunner()
result = runner.invoke( result = runner.invoke(
cli, cli,
["--no-stream", "-t", "template", "Input text"] + extra_args, ["--no-stream", "-t", "template"]
+ ([input_text] if input_text else [])
+ extra_args,
catch_exceptions=False, catch_exceptions=False,
) )
if expected_error is None: if expected_error is None: