mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-21 05:31:01 +00:00
Model.stream() and .get_key() methods
This commit is contained in:
parent
2911975548
commit
ffe4b6706d
2 changed files with 23 additions and 6 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, Optional, Set
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
@ -54,6 +55,7 @@ class Response(ABC):
|
|||
|
||||
class Model(ABC):
|
||||
model_id: str
|
||||
key: Optional[str] = None
|
||||
needs_key: Optional[str] = None
|
||||
key_env_var: Optional[str] = None
|
||||
can_stream: bool = False
|
||||
|
|
@ -62,10 +64,25 @@ class Model(ABC):
|
|||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
def prompt(self, prompt, system=None, stream=True, **options):
|
||||
def get_key(self):
|
||||
if self.needs_key is None:
|
||||
return None
|
||||
if self.key is not None:
|
||||
return self.key
|
||||
if self.key_env_var is not None:
|
||||
return os.environ.get(self.key_env_var)
|
||||
return None
|
||||
|
||||
def prompt(self, prompt, system=None, **options):
|
||||
return self.execute(
|
||||
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
|
||||
stream=stream,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def stream(self, prompt, system=None, **options):
|
||||
return self.execute(
|
||||
Prompt(prompt, system=system, model=self, options=self.Options(**options)),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -53,17 +53,17 @@ class Chat(Model):
|
|||
key_env_var = "OPENAI_API_KEY"
|
||||
can_stream: bool = True
|
||||
|
||||
def __init__(self, model_id, key=None, stream=True):
|
||||
def __init__(self, model_id, key=None):
|
||||
self.model_id = model_id
|
||||
self.stream = stream
|
||||
self.key = key
|
||||
|
||||
def execute(self, prompt: Prompt, stream: bool = True) -> ChatResponse:
|
||||
if self.key is None:
|
||||
key = self.get_key()
|
||||
if key is None:
|
||||
raise NeedsKeyException(
|
||||
"{} needs an API key, label={}".format(str(self), self.needs_key)
|
||||
)
|
||||
return ChatResponse(prompt, stream, key=self.key)
|
||||
return ChatResponse(prompt, stream, key=key)
|
||||
|
||||
def __str__(self):
|
||||
return "OpenAI Chat: {}".format(self.model_id)
|
||||
|
|
|
|||
Loading…
Reference in a new issue