Skip invalid templates in llm templates list, closes #880

This commit is contained in:
Simon Willison 2025-04-05 11:32:35 -07:00
parent 70e0799821
commit cc94111892
3 changed files with 70 additions and 47 deletions

View file

@ -435,7 +435,10 @@ def prompt(
# Cannot be used with system
if system:
raise click.ClickException("Cannot use -t/--template and --system together")
template_obj = load_template(template)
try:
template_obj = load_template(template)
except LoadTemplateError as ex:
raise click.ClickException(str(ex))
extract = template_obj.extract
extract_last = template_obj.extract_last
if template_obj.schema_object:
@ -683,7 +686,10 @@ def chat(
# Cannot be used with system
if system:
raise click.ClickException("Cannot use -t/--template and --system together")
template_obj = load_template(template)
try:
template_obj = load_template(template)
except LoadTemplateError as ex:
raise click.ClickException(str(ex))
if model_id is None and template_obj.model:
model_id = template_obj.model
@ -1500,7 +1506,12 @@ def templates_list():
pairs = []
for file in path.glob("*.yaml"):
name = file.stem
template = load_template(name)
garg = []
try:
template = load_template(name)
except LoadTemplateError:
# Skip invalid templates
continue
text = []
if template.system:
text.append(f"system: {template.system}")
@ -2572,48 +2583,6 @@ def logs_db_path():
return user_dir() / "logs.db"
def _parse_yaml_template(name, content):
try:
loaded = yaml.safe_load(content)
except yaml.YAMLError as ex:
raise click.ClickException("Invalid YAML: {}".format(str(ex)))
if isinstance(loaded, str):
return Template(name=name, prompt=loaded)
loaded["name"] = name
try:
return Template(**loaded)
except pydantic.ValidationError as ex:
msg = "A validation error occurred:\n"
msg += render_errors(ex.errors())
raise click.ClickException(msg)
def load_template(name):
if name.startswith("https://") or name.startswith("http://"):
response = httpx.get(name)
response.raise_for_status()
return _parse_yaml_template(name, response.text)
if ":" in name:
prefix, rest = name.split(":", 1)
loaders = get_template_loaders()
if prefix not in loaders:
raise click.ClickException("Unknown template prefix: {}".format(prefix))
loader = loaders[prefix]
try:
return loader(rest)
except Exception as ex:
raise click.ClickException(
"Could not load template {}: {}".format(name, ex)
)
path = template_dir() / f"{name}.yaml"
if not path.exists():
raise click.ClickException(f"Invalid template: {name}")
content = path.read_text()
return _parse_yaml_template(name, content)
def get_history(chat_id):
if chat_id is None:
return None, []
@ -2757,3 +2726,51 @@ def clear_model_option(model_id: str, key: str) -> None:
del options[model_id]
path.write_text(json.dumps(options, indent=2))
class LoadTemplateError(ValueError):
pass
def _parse_yaml_template(name, content):
try:
loaded = yaml.safe_load(content)
except yaml.YAMLError as ex:
raise LoadTemplateError("Invalid YAML: {}".format(str(ex)))
if isinstance(loaded, str):
return Template(name=name, prompt=loaded)
loaded["name"] = name
try:
return Template(**loaded)
except pydantic.ValidationError as ex:
msg = "A validation error occurred:\n"
msg += render_errors(ex.errors())
raise LoadTemplateError(msg)
def load_template(name: str) -> Template:
"Or raises LoadTemplateError(msg)"
if name.startswith("https://") or name.startswith("http://"):
response = httpx.get(name)
try:
response.raise_for_status()
except httpx.HTTPStatusError as ex:
raise LoadTemplateError("Could not load template {}: {}".format(name, ex))
return _parse_yaml_template(name, response.text)
if ":" in name:
prefix, rest = name.split(":", 1)
loaders = get_template_loaders()
if prefix not in loaders:
raise LoadTemplateError("Unknown template prefix: {}".format(prefix))
loader = loaders[prefix]
try:
return loader(rest)
except Exception as ex:
raise LoadTemplateError("Could not load template {}: {}".format(name, ex))
path = template_dir() / f"{name}.yaml"
if not path.exists():
raise LoadTemplateError(f"Invalid template: {name}")
content = path.read_text()
return _parse_yaml_template(name, content)

View file

@ -240,8 +240,13 @@ def resolve_schema_input(db, schema_input, load_template):
return
if schema_input.strip().startswith("t:"):
name = schema_input.strip()[2:]
template = load_template(name)
if not template.schema_object:
schema_object = None
try:
template = load_template(name)
schema_object = template.schema_object
except ValueError:
raise click.ClickException("Invalid template: {}".format(name))
if not schema_object:
raise click.ClickException("Template '{}' has no schema".format(name))
return template.schema_object
if schema_input.strip().startswith("{"):

View file

@ -63,6 +63,7 @@ def test_templates_list(templates_path, args):
"system: summarize this\nprompt: $input", "utf-8"
)
(templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8")
(templates_path / "invalid.yaml").write_text("system2: This is invalid", "utf-8")
runner = CliRunner()
result = runner.invoke(cli, args)
assert result.exit_code == 0