add test

parent 90889fddc9
commit 36033990d1

3 changed files with 43 additions and 5 deletions
@@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
 def test_chat_completion_with_openai_library():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         messages=[
@@ -170,7 +170,7 @@ def test_chat_completion_with_timings_per_token():
 def test_logprobs():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         temperature=0.0,
@@ -197,7 +197,7 @@ def test_logprobs():
 def test_logprobs_stream():
     global server
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.chat.completions.create(
         model="gpt-3.5-turbo-instruct",
         temperature=0.0,
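
The three hunks above all add the "/v1" prefix to the client's base_url. The openai Python client joins the route onto whatever base URL it is given, so with the prefix the tests explicitly target the server's OpenAI-compatible /v1/... endpoints. A minimal sketch of that URL construction, using a hypothetical host and port that are not part of this diff:

# Sketch only: shows how the openai client derives the request path from base_url.
# The host/port below are placeholders, not values from the diff.
from openai import OpenAI

client = OpenAI(api_key="dummy", base_url="http://127.0.0.1:8080/v1")
# client.chat.completions.create(...) posts to http://127.0.0.1:8080/v1/chat/completions;
# without the "/v1" suffix it would post to http://127.0.0.1:8080/chat/completions instead.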

@@ -1,5 +1,6 @@
 import pytest
 import time
+from openai import OpenAI
 from utils import *

 server = ServerPreset.tinyllama2()

@@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream():
     assert content_stream == res_non_stream.body["content"]


+def test_completion_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+    )
+    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
+    assert res.choices[0].finish_reason == "length"
+    assert res.choices[0].text is not None
+    assert match_regex("(going|bed)+", res.choices[0].text)
+
+
+def test_completion_stream_with_openai_library():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="I believe the meaning of life is",
+        max_tokens=8,
+        stream=True,
+    )
+    output_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            assert choice.text is not None
+            output_text += choice.text
+    assert match_regex("(going|bed)+", output_text)
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
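
The two tests added above exercise the OAI-compatible /v1/completions route through the official openai client, once non-streaming and once with stream=True, and both validate the generated text with match_regex. That helper is pulled in via "from utils import *" and is not shown in this diff; a minimal stand-in with the semantics the assertions rely on might look like the following (an assumption, not the project's actual implementation):

# Hedged stand-in for the match_regex helper used by the new assertions.
# The real helper lives in the tests' utils module and may differ in detail.
import re

def match_regex(regex: str, text: str) -> bool:
    # True when the pattern matches anywhere in the generated text
    return re.search(regex, text) is not None

assert match_regex("(going|bed)+", "I believe the meaning of life is going to bed")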

@@ -570,8 +570,11 @@ static json oaicompat_completion_params_parse(const json & body) {
     }

     // Params supported by OAI but unsupported by llama.cpp
-    if (body.contains("best_of")) {
-        throw std::runtime_error("Unsupported param: best_of");
+    static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
+    for (const auto & param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
     }

     // Copy remaining properties to llama_params
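
With the loop above, the single best_of check becomes a table-driven rejection of every OpenAI parameter the server does not implement (best_of, echo, suffix), so such requests fail loudly instead of being silently accepted. From the client side this presumably surfaces as an API error; a hedged sketch of what that might look like through the openai client, where the host/port, error class, and exact status code are assumptions rather than facts from this diff:

# Sketch only: an unsupported parameter is expected to be rejected by the server.
# Host/port are placeholders; the precise error subclass/status is an assumption.
import openai
from openai import OpenAI

client = OpenAI(api_key="dummy", base_url="http://127.0.0.1:8080/v1")
try:
    client.completions.create(model="davinci-002", prompt="hi", max_tokens=4, echo=True)
except openai.APIStatusError as err:
    print("rejected:", err)  # the message should mention "Unsupported param: echo"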