mirror of
https://github.com/Hopiu/llm.git
synced 2026-03-17 05:00:25 +00:00
Don't require input if template does not use $input, closes #835
This commit is contained in:
parent
bc692e1f19
commit
bfbcc201b7
3 changed files with 24 additions and 4 deletions
|
|
@ -397,9 +397,11 @@ def prompt(
|
|||
extract_last = template_obj.extract_last
|
||||
if template_obj.schema_object:
|
||||
schema = template_obj.schema_object
|
||||
prompt = read_prompt()
|
||||
input_ = ""
|
||||
if "input" in template_obj.vars():
|
||||
input_ = read_prompt()
|
||||
try:
|
||||
prompt, system = template_obj.evaluate(prompt, params)
|
||||
prompt, system = template_obj.evaluate(input_, params)
|
||||
except Template.MissingVariables as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
if model_id is None and template_obj.model:
|
||||
|
|
|
|||
|
|
@ -38,6 +38,14 @@ class Template(BaseModel):
|
|||
system = self.interpolate(self.system, params)
|
||||
return prompt, system
|
||||
|
||||
def vars(self) -> set:
|
||||
all_vars = set()
|
||||
for text in [self.prompt, self.system]:
|
||||
if not text:
|
||||
continue
|
||||
all_vars.update(self.extract_vars(string.Template(text)))
|
||||
return all_vars
|
||||
|
||||
@classmethod
|
||||
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
|
||||
if not text:
|
||||
|
|
|
|||
|
|
@ -146,10 +146,11 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
|
||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
|
||||
@pytest.mark.parametrize(
|
||||
"template,extra_args,expected_model,expected_input,expected_error",
|
||||
"template,input_text,extra_args,expected_model,expected_input,expected_error",
|
||||
(
|
||||
(
|
||||
"'Summarize this: $input'",
|
||||
"Input text",
|
||||
[],
|
||||
"gpt-4o-mini",
|
||||
"Summarize this: Input text",
|
||||
|
|
@ -157,6 +158,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
),
|
||||
(
|
||||
"prompt: 'Summarize this: $input'\nmodel: gpt-4",
|
||||
"Input text",
|
||||
[],
|
||||
"gpt-4",
|
||||
"Summarize this: Input text",
|
||||
|
|
@ -164,6 +166,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
),
|
||||
(
|
||||
"prompt: 'Summarize this: $input'",
|
||||
"Input text",
|
||||
["-m", "4"],
|
||||
"gpt-4",
|
||||
"Summarize this: Input text",
|
||||
|
|
@ -171,6 +174,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
),
|
||||
pytest.param(
|
||||
"boo",
|
||||
"Input text",
|
||||
["-s", "s"],
|
||||
None,
|
||||
None,
|
||||
|
|
@ -179,6 +183,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
),
|
||||
pytest.param(
|
||||
"prompt: 'Say $hello'",
|
||||
"Input text",
|
||||
[],
|
||||
None,
|
||||
None,
|
||||
|
|
@ -187,6 +192,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
),
|
||||
(
|
||||
"prompt: 'Say $hello'",
|
||||
"Input text",
|
||||
["-p", "hello", "Blah"],
|
||||
"gpt-4o-mini",
|
||||
"Say Blah",
|
||||
|
|
@ -194,6 +200,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
|||
),
|
||||
(
|
||||
"prompt: 'Say pelican'",
|
||||
"",
|
||||
[],
|
||||
"gpt-4o-mini",
|
||||
"Say pelican",
|
||||
|
|
@ -205,6 +212,7 @@ def test_template_basic(
|
|||
templates_path,
|
||||
mocked_openai_chat,
|
||||
template,
|
||||
input_text,
|
||||
extra_args,
|
||||
expected_model,
|
||||
expected_input,
|
||||
|
|
@ -214,7 +222,9 @@ def test_template_basic(
|
|||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
["--no-stream", "-t", "template", "Input text"] + extra_args,
|
||||
["--no-stream", "-t", "template"]
|
||||
+ ([input_text] if input_text else [])
|
||||
+ extra_args,
|
||||
catch_exceptions=False,
|
||||
)
|
||||
if expected_error is None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue