mirror of
https://github.com/Hopiu/llm.git
synced 2026-04-23 14:34:46 +00:00
Add support for OpenAI completion models, including gpt-3.5-turbo-instruct (refs #284)
This commit is contained in:
parent
356fcb72f6
commit
4d46ebaa32
5 changed files with 132 additions and 26 deletions
|
|
@ -22,6 +22,10 @@ def register_models(register):
|
|||
register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k"))
|
||||
register(Chat("gpt-4"), aliases=("4", "gpt4"))
|
||||
register(Chat("gpt-4-32k"), aliases=("4-32k",))
|
||||
register(
|
||||
Completion("gpt-3.5-turbo-instruct"),
|
||||
aliases=("3.5-instruct", "chatgpt-instruct"),
|
||||
)
|
||||
# Load extra models
|
||||
extra_path = llm.user_dir() / "extra-openai-models.yaml"
|
||||
if not extra_path.exists():
|
||||
|
|
@ -249,24 +253,7 @@ class Chat(Model):
|
|||
messages.append({"role": "system", "content": prompt.system})
|
||||
messages.append({"role": "user", "content": prompt.prompt})
|
||||
response._prompt_json = {"messages": messages}
|
||||
kwargs = dict(not_nulls(prompt.options))
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.api_type:
|
||||
kwargs["api_type"] = self.api_type
|
||||
if self.api_version:
|
||||
kwargs["api_version"] = self.api_version
|
||||
if self.api_engine:
|
||||
kwargs["engine"] = self.api_engine
|
||||
if self.needs_key:
|
||||
if self.key:
|
||||
kwargs["api_key"] = self.key
|
||||
else:
|
||||
# OpenAI-compatible models don't need a key, but the
|
||||
# openai client library requires one
|
||||
kwargs["api_key"] = "DUMMY_KEY"
|
||||
if self.headers:
|
||||
kwargs["headers"] = self.headers
|
||||
kwargs = self.build_kwargs(prompt)
|
||||
if stream:
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=self.model_name or self.model_id,
|
||||
|
|
@ -291,6 +278,65 @@ class Chat(Model):
|
|||
response.response_json = completion.to_dict_recursive()
|
||||
yield completion.choices[0].message.content
|
||||
|
||||
def build_kwargs(self, prompt):
    """Assemble the keyword arguments shared by the openai.*.create() calls.

    Combines the prompt's non-null options with the model's API
    configuration (base URL, type, version, engine, key, headers).
    """
    # Start from the prompt options, dropping any that were left unset.
    kwargs = {key: value for key, value in prompt.options if value is not None}
    # Copy configured API settings across; note api_engine maps to "engine".
    for attr, kwarg_name in (
        ("api_base", "api_base"),
        ("api_type", "api_type"),
        ("api_version", "api_version"),
        ("api_engine", "engine"),
    ):
        value = getattr(self, attr)
        if value:
            kwargs[kwarg_name] = value
    if self.needs_key:
        # OpenAI-compatible models don't need a key, but the
        # openai client library requires one
        kwargs["api_key"] = self.key or "DUMMY_KEY"
    if self.headers:
        kwargs["headers"] = self.headers
    return kwargs
|
||||
|
||||
|
||||
class Completion(Chat):
    """An OpenAI completion-API model (e.g. gpt-3.5-turbo-instruct).

    Reuses the Chat machinery but calls openai.Completion.create with a
    single newline-joined prompt string instead of chat messages.
    """

    def __str__(self):
        return "OpenAI Completion: {}".format(self.model_id)

    def execute(self, prompt, stream, response, conversation=None):
        """Run the prompt against the completion endpoint, yielding text.

        Prior conversation turns (prompt and reply alternating) are
        prepended so the model sees the full history as one flat prompt.
        """
        messages = []
        if conversation is not None:
            for prev_response in conversation.responses:
                messages.append(prev_response.prompt.prompt)
                messages.append(prev_response.text())
        messages.append(prompt.prompt)
        response._prompt_json = {"messages": messages}
        kwargs = self.build_kwargs(prompt)
        if stream:
            completion = openai.Completion.create(
                model=self.model_name or self.model_id,
                prompt="\n".join(messages),
                stream=True,
                **kwargs,
            )
            chunks = []
            for chunk in completion:
                chunks.append(chunk)
                # "text" may be absent or empty in a streamed chunk.
                # (The previous `if content is not None` guard was dead:
                # the `or ""` meant content was never None, so empty
                # strings were yielded too. Skip genuinely empty chunks.)
                content = chunk["choices"][0].get("text") or ""
                if content:
                    yield content
            response.response_json = combine_chunks(chunks)
        else:
            completion = openai.Completion.create(
                model=self.model_name or self.model_id,
                prompt="\n".join(messages),
                stream=False,
                **kwargs,
            )
            response.response_json = completion.to_dict_recursive()
            yield completion.choices[0]["text"]
|
||||
|
||||
|
||||
def not_nulls(data) -> dict:
    """Return a dict of the (key, value) pairs in *data* whose value is not None."""
    result = {}
    for key, value in data:
        if value is not None:
            result[key] = value
    return result
|
||||
|
|
@ -303,6 +349,9 @@ def combine_chunks(chunks: List[dict]) -> dict:
|
|||
|
||||
for item in chunks:
|
||||
for choice in item["choices"]:
|
||||
if "text" in choice and "delta" not in choice:
|
||||
content += choice["text"]
|
||||
continue
|
||||
if "role" in choice["delta"]:
|
||||
role = choice["delta"]["role"]
|
||||
if "content" in choice["delta"]:
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ def register_embed_demo_model(embed_demo, mock_model):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_openai(requests_mock):
|
||||
def mocked_openai_chat(requests_mock):
|
||||
return requests_mock.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
|
|
@ -150,6 +150,29 @@ def mocked_openai(requests_mock):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mocked_openai_completion(requests_mock):
    """Mock POST https://api.openai.com/v1/completions with a canned reply.

    Returns the requests_mock matcher so tests can inspect last_request.
    """
    canned_response = {
        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
        "object": "text_completion",
        "created": 1589478378,
        "model": "gpt-3.5-turbo-instruct",
        "choices": [
            {
                "text": "\n\nThis is indeed a test",
                "index": 0,
                "logprobs": None,
                "finish_reason": "length",
            }
        ],
        "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
    }
    return requests_mock.post(
        "https://api.openai.com/v1/completions",
        json=canned_response,
        headers={"Content-Type": "application/json"},
    )
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_localai(requests_mock):
|
||||
return requests_mock.post(
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def test_keys_list(monkeypatch, tmpdir, args):
|
|||
assert result2.output.strip() == "openai"
|
||||
|
||||
|
||||
def test_uses_correct_key(mocked_openai, monkeypatch, tmpdir):
|
||||
def test_uses_correct_key(mocked_openai_chat, monkeypatch, tmpdir):
|
||||
user_dir = tmpdir / "user-dir"
|
||||
pathlib.Path(user_dir).mkdir()
|
||||
keys_path = user_dir / "keys.json"
|
||||
|
|
@ -57,7 +57,7 @@ def test_uses_correct_key(mocked_openai, monkeypatch, tmpdir):
|
|||
monkeypatch.setenv("OPENAI_API_KEY", "from-env")
|
||||
|
||||
def assert_key(key):
|
||||
assert mocked_openai.last_request.headers[
|
||||
assert mocked_openai_chat.last_request.headers[
|
||||
"Authorization"
|
||||
] == "Bearer {}".format(key)
|
||||
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ def test_logs_search(user_path, query, expected):
|
|||
assert [record["id"] for record in records] == expected
|
||||
|
||||
|
||||
def test_llm_prompt_creates_log_database(mocked_openai, tmpdir, monkeypatch):
|
||||
def test_llm_prompt_creates_log_database(mocked_openai_chat, tmpdir, monkeypatch):
|
||||
user_path = tmpdir / "user"
|
||||
monkeypatch.setenv("LLM_USER_PATH", str(user_path))
|
||||
runner = CliRunner()
|
||||
|
|
@ -198,7 +198,7 @@ def test_llm_prompt_creates_log_database(mocked_openai, tmpdir, monkeypatch):
|
|||
),
|
||||
)
|
||||
def test_llm_default_prompt(
|
||||
mocked_openai, use_stdin, user_path, logs_off, logs_args, should_log
|
||||
mocked_openai_chat, use_stdin, user_path, logs_off, logs_args, should_log
|
||||
):
|
||||
# Reset the log_path database
|
||||
log_path = user_path / "logs.db"
|
||||
|
|
@ -232,7 +232,7 @@ def test_llm_default_prompt(
|
|||
result = runner.invoke(cli, args, input=input, catch_exceptions=False)
|
||||
assert result.exit_code == 0
|
||||
assert result.output == "Bob, Alice, Eve\n"
|
||||
assert mocked_openai.last_request.headers["Authorization"] == "Bearer X"
|
||||
assert mocked_openai_chat.last_request.headers["Authorization"] == "Bearer X"
|
||||
|
||||
# Was it logged?
|
||||
rows = list(log_db["responses"].rows)
|
||||
|
|
@ -294,6 +294,40 @@ def test_llm_default_prompt(
|
|||
)
|
||||
|
||||
|
||||
def test_openai_completion(mocked_openai_completion, user_path):
    """Prompt gpt-3.5-turbo-instruct end-to-end and verify output + logging."""
    log_path = user_path / "logs.db"
    log_db = sqlite_utils.Database(str(log_path))
    # Start from an empty responses table so the count below is exact.
    log_db["responses"].delete_where()
    runner = CliRunner()
    args = [
        "-m",
        "gpt-3.5-turbo-instruct",
        "Say this is a test",
        "--no-stream",
        "--key",
        "x",
    ]
    result = runner.invoke(cli, args, catch_exceptions=False)
    assert result.exit_code == 0
    assert result.output == "\n\nThis is indeed a test\n"
    # Check it was logged
    rows = list(log_db["responses"].rows)
    assert len(rows) == 1
    row = rows[0]
    expected_subset = {
        "model": "gpt-3.5-turbo-instruct",
        "prompt": "Say this is a test",
        "system": None,
        "prompt_json": '{"messages": ["Say this is a test"]}',
        "options_json": "{}",
        "response": "\n\nThis is indeed a test",
    }
    assert expected_subset.items() <= row.items()
|
||||
|
||||
|
||||
EXTRA_MODELS_YAML = """
|
||||
- model_id: orca
|
||||
model_name: orca-mini-3b
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ def test_templates_prompt_save(templates_path, args, expected_prompt, expected_e
|
|||
)
|
||||
def test_template_basic(
|
||||
templates_path,
|
||||
mocked_openai,
|
||||
mocked_openai_chat,
|
||||
template,
|
||||
extra_args,
|
||||
expected_model,
|
||||
|
|
@ -173,7 +173,7 @@ def test_template_basic(
|
|||
)
|
||||
if expected_error is None:
|
||||
assert result.exit_code == 0
|
||||
assert mocked_openai.last_request.json() == {
|
||||
assert mocked_openai_chat.last_request.json() == {
|
||||
"model": expected_model,
|
||||
"messages": [{"role": "user", "content": expected_input}],
|
||||
"stream": False,
|
||||
|
|
|
|||
Loading…
Reference in a new issue