mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-11 23:33:10 +00:00
Don't require input if template does not use $input, closes #835
This commit is contained in:
parent
bc692e1f19
commit
bfbcc201b7
3 changed files with 24 additions and 4 deletions
|
|
@ -397,9 +397,11 @@ def prompt(
|
||||||
extract_last = template_obj.extract_last
|
extract_last = template_obj.extract_last
|
||||||
if template_obj.schema_object:
|
if template_obj.schema_object:
|
||||||
schema = template_obj.schema_object
|
schema = template_obj.schema_object
|
||||||
prompt = read_prompt()
|
input_ = ""
|
||||||
|
if "input" in template_obj.vars():
|
||||||
|
input_ = read_prompt()
|
||||||
try:
|
try:
|
||||||
prompt, system = template_obj.evaluate(prompt, params)
|
prompt, system = template_obj.evaluate(input_, params)
|
||||||
except Template.MissingVariables as ex:
|
except Template.MissingVariables as ex:
|
||||||
raise click.ClickException(str(ex))
|
raise click.ClickException(str(ex))
|
||||||
if model_id is None and template_obj.model:
|
if model_id is None and template_obj.model:
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,14 @@ class Template(BaseModel):
|
||||||
system = self.interpolate(self.system, params)
|
system = self.interpolate(self.system, params)
|
||||||
return prompt, system
|
return prompt, system
|
||||||
|
|
||||||
|
def vars(self) -> set:
|
||||||
|
all_vars = set()
|
||||||
|
for text in [self.prompt, self.system]:
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
all_vars.update(self.extract_vars(string.Template(text)))
|
||||||
|
return all_vars
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
|
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
|
||||||
if not text:
|
if not text:
|
||||||
|
|
|
||||||
|
|
@ -146,10 +146,11 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
|
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "X"})
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"template,extra_args,expected_model,expected_input,expected_error",
|
"template,input_text,extra_args,expected_model,expected_input,expected_error",
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
"'Summarize this: $input'",
|
"'Summarize this: $input'",
|
||||||
|
"Input text",
|
||||||
[],
|
[],
|
||||||
"gpt-4o-mini",
|
"gpt-4o-mini",
|
||||||
"Summarize this: Input text",
|
"Summarize this: Input text",
|
||||||
|
|
@ -157,6 +158,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"prompt: 'Summarize this: $input'\nmodel: gpt-4",
|
"prompt: 'Summarize this: $input'\nmodel: gpt-4",
|
||||||
|
"Input text",
|
||||||
[],
|
[],
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
"Summarize this: Input text",
|
"Summarize this: Input text",
|
||||||
|
|
@ -164,6 +166,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"prompt: 'Summarize this: $input'",
|
"prompt: 'Summarize this: $input'",
|
||||||
|
"Input text",
|
||||||
["-m", "4"],
|
["-m", "4"],
|
||||||
"gpt-4",
|
"gpt-4",
|
||||||
"Summarize this: Input text",
|
"Summarize this: Input text",
|
||||||
|
|
@ -171,6 +174,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
"boo",
|
"boo",
|
||||||
|
"Input text",
|
||||||
["-s", "s"],
|
["-s", "s"],
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
@ -179,6 +183,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
"prompt: 'Say $hello'",
|
"prompt: 'Say $hello'",
|
||||||
|
"Input text",
|
||||||
[],
|
[],
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
@ -187,6 +192,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"prompt: 'Say $hello'",
|
"prompt: 'Say $hello'",
|
||||||
|
"Input text",
|
||||||
["-p", "hello", "Blah"],
|
["-p", "hello", "Blah"],
|
||||||
"gpt-4o-mini",
|
"gpt-4o-mini",
|
||||||
"Say Blah",
|
"Say Blah",
|
||||||
|
|
@ -194,6 +200,7 @@ def test_templates_error_on_missing_schema(templates_path):
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"prompt: 'Say pelican'",
|
"prompt: 'Say pelican'",
|
||||||
|
"",
|
||||||
[],
|
[],
|
||||||
"gpt-4o-mini",
|
"gpt-4o-mini",
|
||||||
"Say pelican",
|
"Say pelican",
|
||||||
|
|
@ -205,6 +212,7 @@ def test_template_basic(
|
||||||
templates_path,
|
templates_path,
|
||||||
mocked_openai_chat,
|
mocked_openai_chat,
|
||||||
template,
|
template,
|
||||||
|
input_text,
|
||||||
extra_args,
|
extra_args,
|
||||||
expected_model,
|
expected_model,
|
||||||
expected_input,
|
expected_input,
|
||||||
|
|
@ -214,7 +222,9 @@ def test_template_basic(
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
["--no-stream", "-t", "template", "Input text"] + extra_args,
|
["--no-stream", "-t", "template"]
|
||||||
|
+ ([input_text] if input_text else [])
|
||||||
|
+ extra_args,
|
||||||
catch_exceptions=False,
|
catch_exceptions=False,
|
||||||
)
|
)
|
||||||
if expected_error is None:
|
if expected_error is None:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue