From 4d18da4e1149b69b44a0480729b4e2ef24bc756a Mon Sep 17 00:00:00 2001
From: Simon Willison <swillison@gmail.com>
Date: Mon, 18 Sep 2023 20:29:39 -0700
Subject: [PATCH] Bump default gpt-3.5-turbo-instruct max tokens to 256, refs #284

---
 llm/default_plugins/openai_models.py | 10 +++++++++-
 tests/test_llm.py                    |  9 +++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 8b7854e..e76e853 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -23,7 +23,7 @@ def register_models(register):
     register(Chat("gpt-4"), aliases=("4", "gpt4"))
     register(Chat("gpt-4-32k"), aliases=("4-32k",))
     register(
-        Completion("gpt-3.5-turbo-instruct"),
+        Completion("gpt-3.5-turbo-instruct", default_max_tokens=256),
         aliases=("3.5-instruct", "chatgpt-instruct"),
     )
     # Load extra models
@@ -126,6 +126,8 @@ class Chat(Model):
     key_env_var = "OPENAI_API_KEY"
     can_stream: bool = True
 
+    default_max_tokens = None
+
     class Options(llm.Options):
         temperature: Optional[float] = Field(
             description=(
@@ -280,6 +282,8 @@ class Chat(Model):
 
     def build_kwargs(self, prompt):
         kwargs = dict(not_nulls(prompt.options))
+        if "max_tokens" not in kwargs and self.default_max_tokens is not None:
+            kwargs["max_tokens"] = self.default_max_tokens
         if self.api_base:
             kwargs["api_base"] = self.api_base
         if self.api_type:
@@ -301,6 +305,10 @@ class Chat(Model):
 
 
 class Completion(Chat):
+    def __init__(self, *args, default_max_tokens=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.default_max_tokens = default_max_tokens
+
     def __str__(self):
         return "OpenAI Completion: {}".format(self.model_id)
 
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 6b79d47..d049dab 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -313,6 +313,15 @@ def test_openai_completion(mocked_openai_completion, user_path):
     )
     assert result.exit_code == 0
     assert result.output == "\n\nThis is indeed a test\n"
+
+    # Should have requested 256 tokens
+    assert json.loads(mocked_openai_completion.last_request.text) == {
+        "model": "gpt-3.5-turbo-instruct",
+        "prompt": "Say this is a test",
+        "stream": False,
+        "max_tokens": 256,
+    }
+
     # Check it was logged
     rows = list(log_db["responses"].rows)
     assert len(rows) == 1