Bump default gpt-3.5-turbo-instruct max tokens to 256, refs #284
This commit is contained in:
parent 4d46ebaa32
commit 4d18da4e11

2 changed files with 18 additions and 1 deletion
@@ -23,7 +23,7 @@ def register_models(register):
     register(Chat("gpt-4"), aliases=("4", "gpt4"))
     register(Chat("gpt-4-32k"), aliases=("4-32k",))
     register(
-        Completion("gpt-3.5-turbo-instruct"),
+        Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
         aliases=("3.5-instruct", "chatgpt-instruct"),
     )
     # Load extra models
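For context, here is a minimal sketch of how the new default surfaces through the llm Python API (llm.get_model and Model.prompt are the standard entry points; the behavior described in the comments follows from the registration above):

    import llm

    # Look up the completion model via one of its registered aliases.
    model = llm.get_model("3.5-instruct")

    # No max_tokens option is passed here, so build_kwargs() (see the
    # hunk below) falls back to the default_max_tokens=256 set at
    # registration time.
    response = model.prompt("Say this is a test")
    print(response.text())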
@@ -126,6 +126,8 @@ class Chat(Model):
     key_env_var = "OPENAI_API_KEY"
     can_stream: bool = True
 
+    default_max_tokens = None
+
     class Options(llm.Options):
         temperature: Optional[float] = Field(
             description=(
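The new class attribute is an opt-in hook: None means "send no max_tokens and let the API decide", and a subclass or instance can supply a concrete value. A standalone illustration of the pattern (not code from this commit):

    class Base:
        # None = no default; requests go out without max_tokens.
        default_max_tokens = None

    class Capped(Base):
        # Opting in: requests default to 256 tokens unless overridden.
        default_max_tokens = 256

    assert Base.default_max_tokens is None
    assert Capped.default_max_tokens == 256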
@@ -280,6 +282,8 @@ class Chat(Model):
 
     def build_kwargs(self, prompt):
         kwargs = dict(not_nulls(prompt.options))
+        if "max_tokens" not in kwargs and self.default_max_tokens is not None:
+            kwargs["max_tokens"] = self.default_max_tokens
         if self.api_base:
             kwargs["api_base"] = self.api_base
         if self.api_type:
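In isolation, the added guard behaves like this hypothetical helper (build_kwargs_fallback is not part of the commit; it restates the two new lines so the precedence is easy to check):

    def build_kwargs_fallback(options, default_max_tokens):
        # Same rule as the lines added to build_kwargs(): the default
        # applies only when the caller did not set max_tokens itself.
        kwargs = dict(options)
        if "max_tokens" not in kwargs and default_max_tokens is not None:
            kwargs["max_tokens"] = default_max_tokens
        return kwargs

    assert build_kwargs_fallback({}, 256) == {"max_tokens": 256}
    assert build_kwargs_fallback({"max_tokens": 10}, 256) == {"max_tokens": 10}
    assert build_kwargs_fallback({}, None) == {}

An explicit user option always wins; the default only fills a gap.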
@@ -301,6 +305,10 @@ class Chat(Model):
 
 
 class Completion(Chat):
+    def __init__(self, *args, default_max_tokens=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.default_max_tokens = default_max_tokens
+
     def __str__(self):
         return "OpenAI Completion: {}".format(self.model_id)
 
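Because the keyword is optional, existing Completion call sites are unaffected while register_models() can opt in (a sketch using the class from this diff):

    # Old style still works: default_max_tokens stays None, so
    # build_kwargs() adds no max_tokens to the request.
    plain = Completion("gpt-3.5-turbo-instruct")

    # New style from register_models(): requests default to 256 tokens.
    capped = Completion("gpt-3.5-turbo-instruct", default_max_tokens=256)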
@@ -313,6 +313,15 @@ def test_openai_completion(mocked_openai_completion, user_path):
     )
     assert result.exit_code == 0
     assert result.output == "\n\nThis is indeed a test\n"
+
+    # Should have requested 256 tokens
+    assert json.loads(mocked_openai_completion.last_request.text) == {
+        "model": "gpt-3.5-turbo-instruct",
+        "prompt": "Say this is a test",
+        "stream": False,
+        "max_tokens": 256,
+    }
+
     # Check it was logged
     rows = list(log_db["responses"].rows)
     assert len(rows) == 1
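A natural companion check, sketched here rather than taken from the commit, would pass an explicit option and assert that it beats the default. CliRunner, llm.cli.cli, the -o max_tokens syntax, and the mocked_openai_completion fixture are assumed from the llm CLI and the test above:

    import json
    from click.testing import CliRunner
    from llm.cli import cli

    def test_openai_completion_explicit_max_tokens(mocked_openai_completion, user_path):
        # Hypothetical sibling test: an explicit -o max_tokens 10 should
        # reach the API unchanged instead of the 256 default.
        result = CliRunner().invoke(
            cli,
            ["Say this is a test", "-m", "gpt-3.5-turbo-instruct",
             "-o", "max_tokens", "10", "--no-stream"],
        )
        assert result.exit_code == 0
        body = json.loads(mocked_openai_completion.last_request.text)
        assert body["max_tokens"] == 10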