mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-04 11:54:46 +00:00
Renamed template.execute() to template.evaluate() and added type hints
This commit is contained in:
parent
199f7e0767
commit
a421aab7f0
3 changed files with 13 additions and 9 deletions
|
|
@ -171,7 +171,7 @@ def prompt(
|
|||
template_obj = load_template(template)
|
||||
prompt = read_prompt()
|
||||
try:
|
||||
prompt, system = template_obj.execute(prompt, params)
|
||||
prompt, system = template_obj.evaluate(prompt, params)
|
||||
except Template.MissingVariables as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
if model_id is None and template_obj.model:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from pydantic import ConfigDict, BaseModel
|
||||
import string
|
||||
from typing import Optional
|
||||
from typing import Optional, Any, Dict, List, Tuple
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
|
|
@ -8,19 +8,23 @@ class Template(BaseModel):
|
|||
prompt: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
defaults: Optional[dict] = None
|
||||
defaults: Optional[Dict[str, Any]] = None
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
class MissingVariables(Exception):
|
||||
pass
|
||||
|
||||
def execute(self, input, params=None):
|
||||
def evaluate(
|
||||
self, input: str, params: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
params = params or {}
|
||||
params["input"] = input
|
||||
if self.defaults:
|
||||
for k, v in self.defaults.items():
|
||||
if k not in params:
|
||||
params[k] = v
|
||||
prompt: Optional[str] = None
|
||||
system: Optional[str] = None
|
||||
if not self.prompt:
|
||||
system = self.interpolate(self.system, params)
|
||||
prompt = input
|
||||
|
|
@ -30,7 +34,7 @@ class Template(BaseModel):
|
|||
return prompt, system
|
||||
|
||||
@classmethod
|
||||
def interpolate(cls, text, params):
|
||||
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
|
||||
if not text:
|
||||
return text
|
||||
# Confirm all variables in text are provided
|
||||
|
|
@ -44,7 +48,7 @@ class Template(BaseModel):
|
|||
return string_template.substitute(**params)
|
||||
|
||||
@staticmethod
|
||||
def extract_vars(string_template):
|
||||
def extract_vars(string_template: string.Template) -> List[str]:
|
||||
return [
|
||||
match.group("named")
|
||||
for match in string_template.pattern.finditer(string_template.template)
|
||||
|
|
|
|||
|
|
@ -27,16 +27,16 @@ import yaml
|
|||
),
|
||||
),
|
||||
)
|
||||
def test_template_execute(
|
||||
def test_template_evaluate(
|
||||
prompt, system, defaults, params, expected_prompt, expected_system, expected_error
|
||||
):
|
||||
t = Template(name="t", prompt=prompt, system=system, defaults=defaults)
|
||||
if expected_error:
|
||||
with pytest.raises(Template.MissingVariables) as ex:
|
||||
prompt, system = t.execute("input", params)
|
||||
prompt, system = t.evaluate("input", params)
|
||||
assert ex.value.args[0] == expected_error
|
||||
else:
|
||||
prompt, system = t.execute("input", params)
|
||||
prompt, system = t.evaluate("input", params)
|
||||
assert prompt == expected_prompt
|
||||
assert system == expected_system
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue