mirror of
https://github.com/Hopiu/llm.git
synced 2026-05-21 03:51:52 +00:00
Skip invalid templates in llm templates list, closes #880
This commit is contained in:
parent
70e0799821
commit
cc94111892
3 changed files with 70 additions and 47 deletions
107
llm/cli.py
107
llm/cli.py
|
|
@ -435,7 +435,10 @@ def prompt(
|
|||
# Cannot be used with system
|
||||
if system:
|
||||
raise click.ClickException("Cannot use -t/--template and --system together")
|
||||
template_obj = load_template(template)
|
||||
try:
|
||||
template_obj = load_template(template)
|
||||
except LoadTemplateError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
extract = template_obj.extract
|
||||
extract_last = template_obj.extract_last
|
||||
if template_obj.schema_object:
|
||||
|
|
@ -683,7 +686,10 @@ def chat(
|
|||
# Cannot be used with system
|
||||
if system:
|
||||
raise click.ClickException("Cannot use -t/--template and --system together")
|
||||
template_obj = load_template(template)
|
||||
try:
|
||||
template_obj = load_template(template)
|
||||
except LoadTemplateError as ex:
|
||||
raise click.ClickException(str(ex))
|
||||
if model_id is None and template_obj.model:
|
||||
model_id = template_obj.model
|
||||
|
||||
|
|
@ -1500,7 +1506,12 @@ def templates_list():
|
|||
pairs = []
|
||||
for file in path.glob("*.yaml"):
|
||||
name = file.stem
|
||||
template = load_template(name)
|
||||
garg = []
|
||||
try:
|
||||
template = load_template(name)
|
||||
except LoadTemplateError:
|
||||
# Skip invalid templates
|
||||
continue
|
||||
text = []
|
||||
if template.system:
|
||||
text.append(f"system: {template.system}")
|
||||
|
|
@ -2572,48 +2583,6 @@ def logs_db_path():
|
|||
return user_dir() / "logs.db"
|
||||
|
||||
|
||||
def _parse_yaml_template(name, content):
|
||||
try:
|
||||
loaded = yaml.safe_load(content)
|
||||
except yaml.YAMLError as ex:
|
||||
raise click.ClickException("Invalid YAML: {}".format(str(ex)))
|
||||
if isinstance(loaded, str):
|
||||
return Template(name=name, prompt=loaded)
|
||||
loaded["name"] = name
|
||||
try:
|
||||
return Template(**loaded)
|
||||
except pydantic.ValidationError as ex:
|
||||
msg = "A validation error occurred:\n"
|
||||
msg += render_errors(ex.errors())
|
||||
raise click.ClickException(msg)
|
||||
|
||||
|
||||
def load_template(name):
|
||||
if name.startswith("https://") or name.startswith("http://"):
|
||||
response = httpx.get(name)
|
||||
response.raise_for_status()
|
||||
return _parse_yaml_template(name, response.text)
|
||||
|
||||
if ":" in name:
|
||||
prefix, rest = name.split(":", 1)
|
||||
loaders = get_template_loaders()
|
||||
if prefix not in loaders:
|
||||
raise click.ClickException("Unknown template prefix: {}".format(prefix))
|
||||
loader = loaders[prefix]
|
||||
try:
|
||||
return loader(rest)
|
||||
except Exception as ex:
|
||||
raise click.ClickException(
|
||||
"Could not load template {}: {}".format(name, ex)
|
||||
)
|
||||
|
||||
path = template_dir() / f"{name}.yaml"
|
||||
if not path.exists():
|
||||
raise click.ClickException(f"Invalid template: {name}")
|
||||
content = path.read_text()
|
||||
return _parse_yaml_template(name, content)
|
||||
|
||||
|
||||
def get_history(chat_id):
|
||||
if chat_id is None:
|
||||
return None, []
|
||||
|
|
@ -2757,3 +2726,51 @@ def clear_model_option(model_id: str, key: str) -> None:
|
|||
del options[model_id]
|
||||
|
||||
path.write_text(json.dumps(options, indent=2))
|
||||
|
||||
|
||||
class LoadTemplateError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _parse_yaml_template(name, content):
|
||||
try:
|
||||
loaded = yaml.safe_load(content)
|
||||
except yaml.YAMLError as ex:
|
||||
raise LoadTemplateError("Invalid YAML: {}".format(str(ex)))
|
||||
if isinstance(loaded, str):
|
||||
return Template(name=name, prompt=loaded)
|
||||
loaded["name"] = name
|
||||
try:
|
||||
return Template(**loaded)
|
||||
except pydantic.ValidationError as ex:
|
||||
msg = "A validation error occurred:\n"
|
||||
msg += render_errors(ex.errors())
|
||||
raise LoadTemplateError(msg)
|
||||
|
||||
|
||||
def load_template(name: str) -> Template:
|
||||
"Or raises LoadTemplateError(msg)"
|
||||
if name.startswith("https://") or name.startswith("http://"):
|
||||
response = httpx.get(name)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as ex:
|
||||
raise LoadTemplateError("Could not load template {}: {}".format(name, ex))
|
||||
return _parse_yaml_template(name, response.text)
|
||||
|
||||
if ":" in name:
|
||||
prefix, rest = name.split(":", 1)
|
||||
loaders = get_template_loaders()
|
||||
if prefix not in loaders:
|
||||
raise LoadTemplateError("Unknown template prefix: {}".format(prefix))
|
||||
loader = loaders[prefix]
|
||||
try:
|
||||
return loader(rest)
|
||||
except Exception as ex:
|
||||
raise LoadTemplateError("Could not load template {}: {}".format(name, ex))
|
||||
|
||||
path = template_dir() / f"{name}.yaml"
|
||||
if not path.exists():
|
||||
raise LoadTemplateError(f"Invalid template: {name}")
|
||||
content = path.read_text()
|
||||
return _parse_yaml_template(name, content)
|
||||
|
|
|
|||
|
|
@ -240,8 +240,13 @@ def resolve_schema_input(db, schema_input, load_template):
|
|||
return
|
||||
if schema_input.strip().startswith("t:"):
|
||||
name = schema_input.strip()[2:]
|
||||
template = load_template(name)
|
||||
if not template.schema_object:
|
||||
schema_object = None
|
||||
try:
|
||||
template = load_template(name)
|
||||
schema_object = template.schema_object
|
||||
except ValueError:
|
||||
raise click.ClickException("Invalid template: {}".format(name))
|
||||
if not schema_object:
|
||||
raise click.ClickException("Template '{}' has no schema".format(name))
|
||||
return template.schema_object
|
||||
if schema_input.strip().startswith("{"):
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ def test_templates_list(templates_path, args):
|
|||
"system: summarize this\nprompt: $input", "utf-8"
|
||||
)
|
||||
(templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8")
|
||||
(templates_path / "invalid.yaml").write_text("system2: This is invalid", "utf-8")
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, args)
|
||||
assert result.exit_code == 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue