mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-26 07:54:45 +00:00
Refactor Template into templates.py
This commit is contained in:
parent
5e056fad8a
commit
13fb4c2966
2 changed files with 57 additions and 55 deletions
|
|
@ -1,56 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
import string
|
||||
from typing import Optional
|
||||
from .hookspecs import hookimpl # noqa
|
||||
from .hookspecs import hookspec # noqa
|
||||
from .models import Model, Prompt, Response, OptionsError # noqa
|
||||
from .hookspecs import hookimpl
|
||||
from .models import Model, Prompt, Response, OptionsError
|
||||
from .templates import Template
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
name: str
|
||||
prompt: Optional[str]
|
||||
system: Optional[str]
|
||||
model: Optional[str]
|
||||
defaults: Optional[dict]
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
class MissingVariables(Exception):
|
||||
pass
|
||||
|
||||
def execute(self, input, params=None):
|
||||
params = params or {}
|
||||
params["input"] = input
|
||||
if self.defaults:
|
||||
for k, v in self.defaults.items():
|
||||
if k not in params:
|
||||
params[k] = v
|
||||
if not self.prompt:
|
||||
system = self.interpolate(self.system, params)
|
||||
prompt = input
|
||||
else:
|
||||
prompt = self.interpolate(self.prompt, params)
|
||||
system = self.interpolate(self.system, params)
|
||||
return prompt, system
|
||||
|
||||
@classmethod
|
||||
def interpolate(cls, text, params):
|
||||
if not text:
|
||||
return text
|
||||
# Confirm all variables in text are provided
|
||||
string_template = string.Template(text)
|
||||
vars = cls.extract_vars(string_template)
|
||||
missing = [p for p in vars if p not in params]
|
||||
if missing:
|
||||
raise cls.MissingVariables(
|
||||
"Missing variables: {}".format(", ".join(missing))
|
||||
)
|
||||
return string_template.substitute(**params)
|
||||
|
||||
@staticmethod
|
||||
def extract_vars(string_template):
|
||||
return [
|
||||
match.group("named")
|
||||
for match in string_template.pattern.finditer(string_template.template)
|
||||
]
|
||||
__all__ = ["Template", "Model", "Prompt", "Response", "OptionsError", "hookimpl"]
|
||||
|
|
|
|||
53
llm/templates.py
Normal file
53
llm/templates.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from pydantic import BaseModel
|
||||
import string
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
name: str
|
||||
prompt: Optional[str]
|
||||
system: Optional[str]
|
||||
model: Optional[str]
|
||||
defaults: Optional[dict]
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
class MissingVariables(Exception):
|
||||
pass
|
||||
|
||||
def execute(self, input, params=None):
|
||||
params = params or {}
|
||||
params["input"] = input
|
||||
if self.defaults:
|
||||
for k, v in self.defaults.items():
|
||||
if k not in params:
|
||||
params[k] = v
|
||||
if not self.prompt:
|
||||
system = self.interpolate(self.system, params)
|
||||
prompt = input
|
||||
else:
|
||||
prompt = self.interpolate(self.prompt, params)
|
||||
system = self.interpolate(self.system, params)
|
||||
return prompt, system
|
||||
|
||||
@classmethod
|
||||
def interpolate(cls, text, params):
|
||||
if not text:
|
||||
return text
|
||||
# Confirm all variables in text are provided
|
||||
string_template = string.Template(text)
|
||||
vars = cls.extract_vars(string_template)
|
||||
missing = [p for p in vars if p not in params]
|
||||
if missing:
|
||||
raise cls.MissingVariables(
|
||||
"Missing variables: {}".format(", ".join(missing))
|
||||
)
|
||||
return string_template.substitute(**params)
|
||||
|
||||
@staticmethod
|
||||
def extract_vars(string_template):
|
||||
return [
|
||||
match.group("named")
|
||||
for match in string_template.pattern.finditer(string_template.template)
|
||||
]
|
||||
Loading…
Reference in a new issue