Renamed template.execute() to template.evaluate() and added type hints

This commit is contained in:
Simon Willison 2023-07-10 08:27:28 -07:00
parent 199f7e0767
commit a421aab7f0
3 changed files with 13 additions and 9 deletions

View file

@ -171,7 +171,7 @@ def prompt(
template_obj = load_template(template)
prompt = read_prompt()
try:
prompt, system = template_obj.execute(prompt, params)
prompt, system = template_obj.evaluate(prompt, params)
except Template.MissingVariables as ex:
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:

View file

@ -1,6 +1,6 @@
from pydantic import ConfigDict, BaseModel
import string
from typing import Optional
from typing import Optional, Any, Dict, List, Tuple
class Template(BaseModel):
@ -8,19 +8,23 @@ class Template(BaseModel):
prompt: Optional[str] = None
system: Optional[str] = None
model: Optional[str] = None
defaults: Optional[dict] = None
defaults: Optional[Dict[str, Any]] = None
model_config = ConfigDict(extra="forbid")
class MissingVariables(Exception):
pass
def execute(self, input, params=None):
def evaluate(
self, input: str, params: Optional[Dict[str, Any]] = None
) -> Tuple[Optional[str], Optional[str]]:
params = params or {}
params["input"] = input
if self.defaults:
for k, v in self.defaults.items():
if k not in params:
params[k] = v
prompt: Optional[str] = None
system: Optional[str] = None
if not self.prompt:
system = self.interpolate(self.system, params)
prompt = input
@ -30,7 +34,7 @@ class Template(BaseModel):
return prompt, system
@classmethod
def interpolate(cls, text, params):
def interpolate(cls, text: Optional[str], params: Dict[str, Any]) -> Optional[str]:
if not text:
return text
# Confirm all variables in text are provided
@ -44,7 +48,7 @@ class Template(BaseModel):
return string_template.substitute(**params)
@staticmethod
def extract_vars(string_template):
def extract_vars(string_template: string.Template) -> List[str]:
return [
match.group("named")
for match in string_template.pattern.finditer(string_template.template)

View file

@ -27,16 +27,16 @@ import yaml
),
),
)
def test_template_execute(
def test_template_evaluate(
prompt, system, defaults, params, expected_prompt, expected_system, expected_error
):
t = Template(name="t", prompt=prompt, system=system, defaults=defaults)
if expected_error:
with pytest.raises(Template.MissingVariables) as ex:
prompt, system = t.execute("input", params)
prompt, system = t.evaluate("input", params)
assert ex.value.args[0] == expected_error
else:
prompt, system = t.execute("input", params)
prompt, system = t.evaluate("input", params)
assert prompt == expected_prompt
assert system == expected_system