diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 885497081..130da03a1 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
 def test_chat_completion_with_openai_library():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         messages=[
@@ -170,7 +170,7 @@ def test_chat_completion_with_timings_per_token():
 def test_logprobs():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         temperature=0.0,
@@ -197,7 +197,7 @@ def test_logprobs():
 def test_logprobs_stream():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         temperature=0.0,
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index a6b215944..e5e3b6077 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -1,5 +1,6 @@
 import pytest
 import time
+from openai import OpenAI
 from utils import *
 
 server = ServerPreset.tinyllama2()
@@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream():
     assert content_stream == res_non_stream.body["content"]
 
 
+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+    )
+    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
+    assert res.choices[0].finish_reason == "length"
+    assert res.choices[0].text is not None
+    assert match_regex("(going|bed)+", res.choices[0].text)
+
+
+def test_completion_stream_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+        stream=True,
+    )
+    output_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            assert choice.text is not None
+            output_text += choice.text
+    assert match_regex("(going|bed)+", output_text)
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 640cf6a29..8523d4787 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -570,8 +570,11 @@ static json oaicompat_completion_params_parse(const json & body) {
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    if (body.contains("best_of")) {
-        throw std::runtime_error("Unsupported param: best_of");
+    static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
+    for (const auto & param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
     }
 
     // Copy remaining properties to llama_params
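
For reference, a minimal sketch of what the completion paths exercised by these tests look like from a plain OpenAI client, outside the pytest harness. The host, port, and error handling below are assumptions for illustration (a llama-server instance already listening on localhost:8080), not part of the patch:

```python
from openai import OpenAI

# Assumed: a llama-server instance running at localhost:8080.
# Note the /v1 prefix, matching the base_url fix in the tests above.
client = OpenAI(api_key="dummy", base_url="http://localhost:8080/v1")

# Non-streaming completion (cf. test_completion_with_openai_library).
res = client.completions.create(
    model="davinci-002",  # placeholder; the server serves whatever model is loaded
    prompt="I believe the meaning of life is",
    max_tokens=8,
)
print(res.choices[0].text)

# Streaming completion (cf. test_completion_stream_with_openai_library).
stream = client.completions.create(
    model="davinci-002",
    prompt="I believe the meaning of life is",
    max_tokens=8,
    stream=True,
)
for chunk in stream:
    choice = chunk.choices[0]
    if choice.finish_reason is None:
        print(choice.text, end="", flush=True)
print()

# best_of, echo and suffix are accepted by the OpenAI API but rejected by the
# server after this change ("Unsupported param: ..."), so the client sees an
# error response rather than a silent no-op.
try:
    client.completions.create(
        model="davinci-002",
        prompt="test",
        max_tokens=4,
        echo=True,
    )
except Exception as err:
    print(f"rejected: {err}")
```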