diff --git a/llm/cli.py b/llm/cli.py index 91f8294..f9faaac 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -435,7 +435,10 @@ def prompt( # Cannot be used with system if system: raise click.ClickException("Cannot use -t/--template and --system together") - template_obj = load_template(template) + try: + template_obj = load_template(template) + except LoadTemplateError as ex: + raise click.ClickException(str(ex)) extract = template_obj.extract extract_last = template_obj.extract_last if template_obj.schema_object: @@ -683,7 +686,10 @@ def chat( # Cannot be used with system if system: raise click.ClickException("Cannot use -t/--template and --system together") - template_obj = load_template(template) + try: + template_obj = load_template(template) + except LoadTemplateError as ex: + raise click.ClickException(str(ex)) if model_id is None and template_obj.model: model_id = template_obj.model @@ -1500,7 +1506,12 @@ def templates_list(): pairs = [] for file in path.glob("*.yaml"): name = file.stem - template = load_template(name) + garg = [] + try: + template = load_template(name) + except LoadTemplateError: + # Skip invalid templates + continue text = [] if template.system: text.append(f"system: {template.system}") @@ -2572,48 +2583,6 @@ def logs_db_path(): return user_dir() / "logs.db" -def _parse_yaml_template(name, content): - try: - loaded = yaml.safe_load(content) - except yaml.YAMLError as ex: - raise click.ClickException("Invalid YAML: {}".format(str(ex))) - if isinstance(loaded, str): - return Template(name=name, prompt=loaded) - loaded["name"] = name - try: - return Template(**loaded) - except pydantic.ValidationError as ex: - msg = "A validation error occurred:\n" - msg += render_errors(ex.errors()) - raise click.ClickException(msg) - - -def load_template(name): - if name.startswith("https://") or name.startswith("http://"): - response = httpx.get(name) - response.raise_for_status() - return _parse_yaml_template(name, response.text) - - if ":" in name: - prefix, rest = name.split(":", 1) - loaders = get_template_loaders() - if prefix not in loaders: - raise click.ClickException("Unknown template prefix: {}".format(prefix)) - loader = loaders[prefix] - try: - return loader(rest) - except Exception as ex: - raise click.ClickException( - "Could not load template {}: {}".format(name, ex) - ) - - path = template_dir() / f"{name}.yaml" - if not path.exists(): - raise click.ClickException(f"Invalid template: {name}") - content = path.read_text() - return _parse_yaml_template(name, content) - - def get_history(chat_id): if chat_id is None: return None, [] @@ -2757,3 +2726,51 @@ def clear_model_option(model_id: str, key: str) -> None: del options[model_id] path.write_text(json.dumps(options, indent=2)) + + +class LoadTemplateError(ValueError): + pass + + +def _parse_yaml_template(name, content): + try: + loaded = yaml.safe_load(content) + except yaml.YAMLError as ex: + raise LoadTemplateError("Invalid YAML: {}".format(str(ex))) + if isinstance(loaded, str): + return Template(name=name, prompt=loaded) + loaded["name"] = name + try: + return Template(**loaded) + except pydantic.ValidationError as ex: + msg = "A validation error occurred:\n" + msg += render_errors(ex.errors()) + raise LoadTemplateError(msg) + + +def load_template(name: str) -> Template: + "Or raises LoadTemplateError(msg)" + if name.startswith("https://") or name.startswith("http://"): + response = httpx.get(name) + try: + response.raise_for_status() + except httpx.HTTPStatusError as ex: + raise LoadTemplateError("Could not load template {}: {}".format(name, ex)) + return _parse_yaml_template(name, response.text) + + if ":" in name: + prefix, rest = name.split(":", 1) + loaders = get_template_loaders() + if prefix not in loaders: + raise LoadTemplateError("Unknown template prefix: {}".format(prefix)) + loader = loaders[prefix] + try: + return loader(rest) + except Exception as ex: + raise LoadTemplateError("Could not load template {}: {}".format(name, ex)) + + path = template_dir() / f"{name}.yaml" + if not path.exists(): + raise LoadTemplateError(f"Invalid template: {name}") + content = path.read_text() + return _parse_yaml_template(name, content) diff --git a/llm/utils.py b/llm/utils.py index 84a012d..686890b 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -240,8 +240,13 @@ def resolve_schema_input(db, schema_input, load_template): return if schema_input.strip().startswith("t:"): name = schema_input.strip()[2:] - template = load_template(name) - if not template.schema_object: + schema_object = None + try: + template = load_template(name) + schema_object = template.schema_object + except ValueError: + raise click.ClickException("Invalid template: {}".format(name)) + if not schema_object: raise click.ClickException("Template '{}' has no schema".format(name)) return template.schema_object if schema_input.strip().startswith("{"): diff --git a/tests/test_templates.py b/tests/test_templates.py index a900046..1bc98e4 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -63,6 +63,7 @@ def test_templates_list(templates_path, args): "system: summarize this\nprompt: $input", "utf-8" ) (templates_path / "sys.yaml").write_text("system: Summarize this", "utf-8") + (templates_path / "invalid.yaml").write_text("system2: This is invalid", "utf-8") runner = CliRunner() result = runner.invoke(cli, args) assert result.exit_code == 0