server: add OpenAI compatible response format for /completions with backward compatibility
This commit is contained in:
parent
3b4f2e33e2
commit
938dbd4d6e
9 changed files with 470 additions and 186 deletions
|
@ -924,7 +924,7 @@ struct server_context {
|
|||
slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
||||
slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
|
||||
slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
||||
slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
||||
slot.params.sampling.n_probs = json_value(data, "n_probs", json_value(data, "logprobs", defaults.sampling.n_probs));
|
||||
slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
||||
|
||||
slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
||||
|
@ -1340,7 +1340,8 @@ struct server_context {
|
|||
}
|
||||
slot.n_sent_token_probs = probs_stop_pos;
|
||||
|
||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
|
||||
// TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
|
||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output, true);
|
||||
}
|
||||
|
||||
if (slot.oaicompat) {
|
||||
|
@ -1379,7 +1380,7 @@ struct server_context {
|
|||
{"timings", slot.get_formated_timings()},
|
||||
{"index", slot.index},
|
||||
};
|
||||
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
std::vector<completion_token_output> probs;
|
||||
if (!slot.params.stream && slot.stopped_word) {
|
||||
|
@ -1395,7 +1396,8 @@ struct server_context {
|
|||
slot.generated_token_probs.end());
|
||||
}
|
||||
|
||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
|
||||
// TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
|
||||
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs, true);
|
||||
}
|
||||
|
||||
if (slot.oaicompat) {
|
||||
|
@ -2901,31 +2903,63 @@ int main(int argc, char ** argv) {
|
|||
res_ok(res, {{ "success", true }});
|
||||
};
|
||||
|
||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
||||
const auto handle_completions_generic = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](server_task_inf_type inf_type, json & data, httplib::Response & res, bool is_chat = false) {
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse data for /chat/completions format if needed
|
||||
if (is_chat) {
|
||||
data = oaicompat_completion_params_parse(ctx_server.model, data, params.chat_template);
|
||||
}
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
bool oai_compat = json_value(data, "oai_compat", true);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, results[0].data);
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & res : results) {
|
||||
arr.push_back(res.data);
|
||||
if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
|
||||
if (is_chat) {
|
||||
// multitask is never supported in chat completion, there is only one result
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
|
||||
/*.streaming =*/ false, verbose, /*.legacy_format =*/ !is_chat);
|
||||
res_ok(res, result_oai);
|
||||
} else {
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
|
||||
/*.streaming =*/ false, verbose, /*.legacy_format =*/ true);
|
||||
res_ok(res, result_oai);
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & result : results) {
|
||||
arr.push_back(format_final_response_oaicompat(data, result.data, completion_id,
|
||||
/*.streaming =*/ false, verbose, /*.legacy_format =*/ true));
|
||||
}
|
||||
res_ok(res, arr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else{
|
||||
if (results.size() == 1) {
|
||||
// single result
|
||||
res_ok(res, results[0].data);
|
||||
} else {
|
||||
// multiple results (multitask)
|
||||
json arr = json::array();
|
||||
for (const auto & res : results) {
|
||||
arr.push_back(res.data);
|
||||
}
|
||||
res_ok(res, arr);
|
||||
}
|
||||
res_ok(res, arr);
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
|
@ -2933,14 +2967,35 @@ int main(int argc, char ** argv) {
|
|||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, is_chat, inf_type, oai_compat](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
|
||||
return server_sent_event(sink, "data", result.data);
|
||||
if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
|
||||
|
||||
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, !is_chat);
|
||||
for (auto & event_data : result_array) {
|
||||
if (event_data.empty()) {
|
||||
continue; // skip the stop token
|
||||
}
|
||||
if (!server_sent_event(sink, "data", event_data)) {
|
||||
return false; // connection is closed
|
||||
}
|
||||
}
|
||||
return true; // ok
|
||||
|
||||
}
|
||||
return server_sent_event(sink, "data", result.data);
|
||||
|
||||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
|
||||
});
|
||||
|
||||
if (is_chat) {
|
||||
static const std::string ev_done = "data: [DONE]\n\n";
|
||||
sink.write(ev_done.data(), ev_done.size());
|
||||
}
|
||||
sink.done();
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
auto on_complete = [task_ids, &ctx_server] (bool) {
|
||||
|
@ -2953,7 +3008,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, false);
|
||||
};
|
||||
|
||||
const auto handle_chat_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
|
||||
};
|
||||
|
||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||
|
@ -3006,63 +3066,6 @@ int main(int argc, char ** argv) {
|
|||
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
|
||||
};
|
||||
|
||||
// TODO: maybe merge this function with "handle_completions_generic"
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
|
||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||
ctx_server.queue_tasks.post(tasks);
|
||||
|
||||
bool stream = json_value(data, "stream", false);
|
||||
const auto task_ids = server_task::get_list_id(tasks);
|
||||
const auto completion_id = gen_chatcmplid();
|
||||
|
||||
if (!stream) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
|
||||
// multitask is never support in chat completion, there is only one result
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
|
||||
res_ok(res, result_oai);
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
});
|
||||
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
} else {
|
||||
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
|
||||
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
|
||||
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
|
||||
for (auto & event_data : result_array) {
|
||||
if (event_data.empty()) {
|
||||
continue; // skip the stop token
|
||||
}
|
||||
if (!server_sent_event(sink, "data", event_data)) {
|
||||
return false; // connection is closed
|
||||
}
|
||||
}
|
||||
return true; // ok
|
||||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
});
|
||||
static const std::string ev_done = "data: [DONE]\n\n";
|
||||
sink.write(ev_done.data(), ev_done.size());
|
||||
sink.done();
|
||||
return true;
|
||||
};
|
||||
|
||||
auto on_complete = [task_ids, &ctx_server] (bool) {
|
||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||
};
|
||||
|
||||
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
json models = {
|
||||
{"object", "list"},
|
||||
|
|
|
@ -40,9 +40,19 @@ def test_load_split_model():
|
|||
server.model_alias = "tinyllama-split"
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 16,
|
||||
"max_tokens": 16,
|
||||
"prompt": "Hello",
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(little|girl)+", res.body["content"])
|
||||
# Verify response structure
|
||||
assert "id" in res.body
|
||||
assert "object" in res.body
|
||||
assert "created" in res.body
|
||||
assert "model" in res.body
|
||||
assert "choices" in res.body
|
||||
assert isinstance(res.body["choices"], list)
|
||||
assert len(res.body["choices"]) > 0
|
||||
assert "text" in res.body["choices"][0]
|
||||
# Verify the actual content
|
||||
assert match_regex("(little|girl)+", res.body["choices"][0]["text"])
|
||||
|
|
|
@ -18,13 +18,13 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
|
|||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"max_tokens": n_predict,
|
||||
"prompt": prompt,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["prompt_n"] == n_prompt
|
||||
assert res.body["timings"]["predicted_n"] == n_predicted
|
||||
assert res.body["truncated"] == truncated
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
|
@ -36,16 +36,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
|
|||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"max_tokens": n_predict,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
content = ""
|
||||
for data in res:
|
||||
if data["stop"]:
|
||||
assert data["timings"]["prompt_n"] == n_prompt
|
||||
assert data["timings"]["predicted_n"] == n_predicted
|
||||
assert data["truncated"] == truncated
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
content += data["content"]
|
||||
|
@ -63,6 +63,7 @@ def test_consistent_result_same_seed(n_slots: int):
|
|||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
"oai_compat": False,
|
||||
})
|
||||
if last_res is not None:
|
||||
assert res.body["content"] == last_res.body["content"]
|
||||
|
@ -81,6 +82,7 @@ def test_different_result_different_seed(n_slots: int):
|
|||
"seed": seed,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
"oai_compat": False,
|
||||
})
|
||||
if last_res is not None:
|
||||
assert res.body["content"] != last_res.body["content"]
|
||||
|
@ -100,6 +102,7 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
|
|||
"seed": 42,
|
||||
"temperature": temperature,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
"oai_compat": False,
|
||||
})
|
||||
if last_res is not None:
|
||||
assert res.body["content"] == last_res.body["content"]
|
||||
|
@ -115,12 +118,14 @@ def test_cache_vs_nocache_prompt():
|
|||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
res_no_cache = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res_cache.body["content"] == res_no_cache.body["content"]
|
||||
|
||||
|
@ -140,6 +145,7 @@ def test_completion_with_tokens_input():
|
|||
# single completion
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": tokens,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
|
@ -147,6 +153,7 @@ def test_completion_with_tokens_input():
|
|||
# batch completion
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [tokens, tokens],
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
|
@ -156,6 +163,7 @@ def test_completion_with_tokens_input():
|
|||
# mixed string and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [tokens, prompt_str],
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
|
@ -165,6 +173,7 @@ def test_completion_with_tokens_input():
|
|||
# mixed string and tokens in one sequence
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body["content"]) == str
|
||||
|
@ -208,6 +217,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||
"prompt": prompt,
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"oai_compat": False,
|
||||
})))
|
||||
tasks.append((check_slots_status, ()))
|
||||
results = parallel_function_calls(tasks)
|
||||
|
@ -221,3 +231,122 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||
assert len(res.body["content"]) > 10
|
||||
# FIXME: the result is not deterministic when using other slot than slot 0
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
# OpenAI legacy completion endpoint tests
|
||||
@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
|
||||
("I believe the meaning of life is", 8, "going to bed", 18, 8),
|
||||
("Write a joke about", 16, "Why did the AI", 14, 16),
|
||||
])
|
||||
def test_completion_openai(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
|
||||
global server
|
||||
server.start()
|
||||
|
||||
# Test non-streaming response
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"model": "local-model",
|
||||
"prompt": prompt,
|
||||
"max_tokens": n_predict,
|
||||
"logprobs": 3,
|
||||
"echo": True
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["object"] == "text_completion"
|
||||
assert isinstance(res.body["id"], str)
|
||||
assert isinstance(res.body["created"], int)
|
||||
assert res.body["model"] == "local-model"
|
||||
|
||||
# Check choices array
|
||||
assert len(res.body["choices"]) == 1
|
||||
choice = res.body["choices"][0]
|
||||
assert choice["index"] == 0
|
||||
assert isinstance(choice["text"], str)
|
||||
assert choice["finish_reason"] in ["stop", "length"]
|
||||
|
||||
# Check logprobs
|
||||
assert choice["logprobs"] is not None
|
||||
assert "tokens" in choice["logprobs"]
|
||||
assert "token_logprobs" in choice["logprobs"]
|
||||
assert "top_logprobs" in choice["logprobs"]
|
||||
assert len(choice["logprobs"]["top_logprobs"]) == len(choice["logprobs"]["tokens"])
|
||||
|
||||
# Check usage statistics
|
||||
assert "usage" in res.body
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
|
||||
("I believe the meaning of life is", 8, "going to bed", 18, 8),
|
||||
("Write a joke about", 16, "Why did the AI", 14, 16),
|
||||
])
|
||||
def test_completion_openai_stream(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
|
||||
global server
|
||||
server.start()
|
||||
|
||||
res = server.make_stream_request("POST", "/v1/completions", data={
|
||||
"prompt": prompt,
|
||||
"max_tokens": n_predict,
|
||||
"stream": True,
|
||||
})
|
||||
|
||||
collected_text = ""
|
||||
is_first_chunk = True
|
||||
for data in res:
|
||||
assert "id" in data
|
||||
assert data["object"] == "text_completion"
|
||||
assert isinstance(data["created"], int)
|
||||
|
||||
assert len(data["choices"]) == 1
|
||||
choice = data["choices"][0]
|
||||
assert choice["index"] == 0
|
||||
assert isinstance(choice["text"], str)
|
||||
collected_text += choice["text"]
|
||||
|
||||
if is_first_chunk:
|
||||
# First chunk should have model info
|
||||
is_first_chunk = False
|
||||
|
||||
if choice["finish_reason"] is not None:
|
||||
# This is the last chunk
|
||||
assert choice["finish_reason"] in ["stop", "length"]
|
||||
assert "usage" in data
|
||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||
assert data["usage"]["completion_tokens"] == n_predicted
|
||||
assert data["usage"]["total_tokens"] == n_prompt + n_predicted
|
||||
|
||||
@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
|
||||
("I believe the meaning of life is", 8, "going to bed", 18, 8),
|
||||
("Write a joke about", 16, "Why did the AI", 14, 16),
|
||||
])
|
||||
def test_completion_openai_no_logprobs(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
|
||||
global server
|
||||
server.start()
|
||||
|
||||
# Test non-streaming response
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": prompt,
|
||||
"max_tokens": n_predict,
|
||||
"echo": True
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.body["object"] == "text_completion"
|
||||
assert isinstance(res.body["id"], str)
|
||||
assert isinstance(res.body["created"], int)
|
||||
|
||||
# Check choices array
|
||||
assert len(res.body["choices"]) == 1
|
||||
choice = res.body["choices"][0]
|
||||
assert choice["index"] == 0
|
||||
assert isinstance(choice["text"], str)
|
||||
assert choice["finish_reason"] in ["stop", "length"]
|
||||
|
||||
# Verify logprobs is None when not requested
|
||||
assert choice["logprobs"] is None
|
||||
|
||||
# Check usage statistics
|
||||
assert "usage" in res.body
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
|
|
@ -29,6 +29,7 @@ def test_ctx_shift_enabled():
|
|||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": LONG_TEXT,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["prompt_n"] == 109
|
||||
|
@ -48,6 +49,7 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
|
|||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": "Hi how are you",
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["timings"]["predicted_n"] == n_token_output
|
||||
|
@ -61,6 +63,7 @@ def test_ctx_shift_disabled_long_prompt():
|
|||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": LONG_TEXT,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code != 200
|
||||
assert "error" in res.body
|
||||
|
|
|
@ -36,6 +36,7 @@ def test_lora(scale: float, re_content: str):
|
|||
assert res_lora_control.status_code == 200
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Look in thy glass",
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
|
|
@ -41,6 +41,7 @@ def test_correct_api_key():
|
|||
server.start()
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"oai_compat": False,
|
||||
}, headers={
|
||||
"Authorization": f"Bearer {TEST_API_KEY}",
|
||||
})
|
||||
|
|
|
@ -20,6 +20,7 @@ def test_slot_save_restore():
|
|||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["content"])
|
||||
|
@ -37,6 +38,7 @@ def test_slot_save_restore():
|
|||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
|
@ -54,6 +56,7 @@ def test_slot_save_restore():
|
|||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 0,
|
||||
"cache_prompt": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
|
@ -64,6 +67,7 @@ def test_slot_save_restore():
|
|||
"prompt": "What is the capital of Germany?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Jack|said)+", res.body["content"])
|
||||
|
@ -78,6 +82,7 @@ def test_slot_erase():
|
|||
"prompt": "What is the capital of France?",
|
||||
"id_slot": 1,
|
||||
"cache_prompt": True,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["content"])
|
||||
|
@ -94,5 +99,5 @@ def test_slot_erase():
|
|||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["content"])
|
||||
assert match_regex("(Whiskers|Flana)+", res.body["choices"][0]["text"])
|
||||
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
|
||||
|
|
|
@ -37,6 +37,7 @@ def test_with_and_without_draft():
|
|||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
content_no_draft = res.body["content"]
|
||||
|
@ -49,6 +50,7 @@ def test_with_and_without_draft():
|
|||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
content_draft = res.body["content"]
|
||||
|
@ -75,6 +77,7 @@ def test_different_draft_min_draft_max():
|
|||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"oai_compat": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
if last_content is not None:
|
||||
|
@ -96,6 +99,7 @@ def test_multi_requests_parallel(n_slots: int, n_requests: int):
|
|||
"prompt": "I believe the meaning of life is",
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"oai_compat": False,
|
||||
})))
|
||||
results = parallel_function_calls(tasks)
|
||||
for res in results:
|
||||
|
|
|
@ -498,28 +498,105 @@ struct completion_token_output {
|
|||
};
|
||||
|
||||
// convert a vector of completion_token_output to json
|
||||
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
|
||||
json out = json::array();
|
||||
|
||||
for (const auto & prob : probs) {
|
||||
json probs_for_token = json::array();
|
||||
|
||||
for (const auto & p : prob.probs) {
|
||||
const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
|
||||
probs_for_token.push_back(json {
|
||||
{"tok_str", tok_str},
|
||||
{"prob", p.prob},
|
||||
});
|
||||
static json probs_vector_to_json(llama_context * ctx, const std::vector<completion_token_output> & probs, bool legacy_format = true) {
|
||||
if (legacy_format) {
|
||||
// Legacy format (text_completion endpoint)
|
||||
json logprobs;
|
||||
std::vector<std::string> tokens;
|
||||
std::vector<json> token_logprobs; // Changed to json to allow null values
|
||||
std::vector<json> top_logprobs; // Changed to allow null values
|
||||
std::vector<int> text_offset;
|
||||
|
||||
int current_offset = 0;
|
||||
|
||||
for (const auto & prob : probs) {
|
||||
std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
|
||||
tokens.push_back(token_str);
|
||||
text_offset.push_back(current_offset);
|
||||
|
||||
// Handle token logprobs
|
||||
if (!prob.probs.empty() && prob.probs[0].prob > 0) {
|
||||
token_logprobs.push_back(std::log(prob.probs[0].prob));
|
||||
} else {
|
||||
token_logprobs.push_back(nullptr);
|
||||
}
|
||||
|
||||
// Handle top logprobs
|
||||
json token_top_logprobs = json::object();
|
||||
for (const auto & p : prob.probs) {
|
||||
if (p.prob > 0) {
|
||||
token_top_logprobs[tokens_to_output_formatted_string(ctx, p.tok)] = std::log(p.prob);
|
||||
}
|
||||
}
|
||||
top_logprobs.push_back(token_top_logprobs.empty() ? nullptr : token_top_logprobs);
|
||||
|
||||
current_offset += token_str.length();
|
||||
}
|
||||
|
||||
logprobs = {
|
||||
{"tokens", tokens},
|
||||
{"token_logprobs", token_logprobs},
|
||||
{"top_logprobs", top_logprobs},
|
||||
{"text_offset", text_offset}
|
||||
};
|
||||
|
||||
return logprobs;
|
||||
} else {
|
||||
// New format (GPT-4 style)
|
||||
json logprobs;
|
||||
std::vector<json> content;
|
||||
|
||||
for (const auto & prob : probs) {
|
||||
std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
|
||||
|
||||
// Create top_logprobs array for this token
|
||||
json token_top_logprobs = json::array();
|
||||
for (const auto & p : prob.probs) {
|
||||
if (p.prob > 0) {
|
||||
// Get UTF-8 bytes for the token
|
||||
std::vector<int> bytes;
|
||||
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
|
||||
for (unsigned char c : tok_str) {
|
||||
bytes.push_back(static_cast<int>(c));
|
||||
}
|
||||
|
||||
json entry = {
|
||||
{"token", tok_str},
|
||||
{"logprob", std::log(p.prob)},
|
||||
{"bytes", bytes.empty() ? json(nullptr) : json(bytes)}
|
||||
};
|
||||
token_top_logprobs.push_back(entry);
|
||||
}
|
||||
}
|
||||
|
||||
const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
|
||||
out.push_back(json {
|
||||
{"content", tok_str},
|
||||
{"probs", probs_for_token},
|
||||
});
|
||||
// Get main token logprob
|
||||
float main_logprob = (!prob.probs.empty() && prob.probs[0].prob > 0)
|
||||
? std::log(prob.probs[0].prob)
|
||||
: -9999.0f;
|
||||
|
||||
// Get UTF-8 bytes for the main token
|
||||
std::vector<int> main_bytes;
|
||||
for (unsigned char c : token_str) {
|
||||
main_bytes.push_back(static_cast<int>(c));
|
||||
}
|
||||
|
||||
// Add token info to content array
|
||||
json token_info = {
|
||||
{"token", token_str},
|
||||
{"logprob", main_logprob},
|
||||
{"bytes", main_bytes.empty() ? json(nullptr) : json(main_bytes)},
|
||||
{"top_logprobs", token_top_logprobs}
|
||||
};
|
||||
content.push_back(token_info);
|
||||
}
|
||||
|
||||
logprobs = {
|
||||
{"content", content},
|
||||
{"refusal", nullptr} // Add refusal field as null since we don't implement content filtering
|
||||
};
|
||||
|
||||
return logprobs;
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
|
||||
|
@ -540,7 +617,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
|
|||
static json oaicompat_completion_params_parse(
|
||||
const struct llama_model * model,
|
||||
const json & body, /* openai api json semantics */
|
||||
const std::string & chat_template) {
|
||||
const std::string & chat_template
|
||||
) {
|
||||
json llama_params;
|
||||
|
||||
llama_params["__oaicompat"] = true;
|
||||
|
@ -604,43 +682,71 @@ static json oaicompat_completion_params_parse(
|
|||
return llama_params;
|
||||
}
|
||||
|
||||
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
|
||||
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool legacy_format = false) {
|
||||
bool stopped_word = result.count("stopped_word") != 0;
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
bool truncated = json_value(result, "truncated", false);
|
||||
|
||||
std::string finish_reason = "length";
|
||||
if (stopped_word || stopped_eos) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choices =
|
||||
streaming ? json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}})
|
||||
: json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json{{"content", content},
|
||||
{"role", "assistant"}}}}});
|
||||
json choices;
|
||||
// Use the pre-formatted logprobs directly
|
||||
json logprobs = result.contains("completion_probabilities") ?
|
||||
result["completion_probabilities"] : nullptr;
|
||||
if (legacy_format) {
|
||||
|
||||
choices = json::array({json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"logprobs", logprobs},
|
||||
{"text", content}
|
||||
}});
|
||||
} else {
|
||||
// Format for chat completions endpoint
|
||||
choices = streaming ?
|
||||
json::array({json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}
|
||||
}}) :
|
||||
json::array({json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", json{
|
||||
{"content", content},
|
||||
{"role", "assistant"}
|
||||
}}
|
||||
}});
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json res = json {
|
||||
{"choices", choices},
|
||||
{"created", t},
|
||||
{"model",
|
||||
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", legacy_format ? "text_completion" :
|
||||
(streaming ? "chat.completion.chunk" : "chat.completion")},
|
||||
{"usage", json {
|
||||
{"completion_tokens", num_tokens_predicted},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||
}},
|
||||
{"id", completion_id}
|
||||
{"id", completion_id},
|
||||
{"truncated", truncated}
|
||||
};
|
||||
|
||||
// Add system_fingerprint if provided
|
||||
if (result.contains("system_fingerprint")) {
|
||||
res["system_fingerprint"] = result["system_fingerprint"];
|
||||
}
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
res["__verbose"] = result;
|
||||
|
@ -658,105 +764,127 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
|||
}
|
||||
|
||||
// return value is vector as there is one case where we might need to generate two responses
|
||||
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
|
||||
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
||||
return std::vector<json>({result});
|
||||
}
|
||||
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool legacy_format = false) {
|
||||
// Early return if required fields are missing
|
||||
// if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
||||
// return std::vector<json>({result});
|
||||
// }
|
||||
|
||||
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
|
||||
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
bool stopped_word = json_value(result, "stopped_word", false);
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
bool stopped_limit = json_value(result, "stopped_limit", false);
|
||||
std::string content = json_value(result, "content", std::string(""));
|
||||
|
||||
// Determine finish reason
|
||||
std::string finish_reason;
|
||||
if (stopped_word || stopped_eos) {
|
||||
if (json_value(result, "stopped_word", false) || json_value(result, "stopped_eos", false)) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
if (stopped_limit) {
|
||||
if (json_value(result, "stopped_limit", false)) {
|
||||
finish_reason = "length";
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json choices;
|
||||
|
||||
if (!finish_reason.empty()) {
|
||||
choices = json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}});
|
||||
} else {
|
||||
if (first) {
|
||||
if (content.empty()) {
|
||||
choices = json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"role", "assistant"}}}}});
|
||||
} else {
|
||||
// We have to send this as two updates to conform to openai behavior
|
||||
json initial_ret = json{{"choices", json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{
|
||||
{"role", "assistant"}
|
||||
}}}})},
|
||||
{"created", t},
|
||||
{"id", completion_id},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
json second_ret = json{
|
||||
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{
|
||||
{"content", content}}}
|
||||
}})},
|
||||
{"created", t},
|
||||
{"id", completion_id},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
return std::vector<json>({initial_ret, second_ret});
|
||||
}
|
||||
// Final message with finish reason
|
||||
if (legacy_format) {
|
||||
choices = json::array({json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"logprobs", result.contains("completion_probabilities") ?
|
||||
result["completion_probabilities"] : nullptr},
|
||||
{"text", content}
|
||||
}});
|
||||
} else {
|
||||
// Some idiosyncrasy in task processing logic makes several trailing calls
|
||||
// with empty content, we ignore these at the calee site.
|
||||
if (content.empty()) {
|
||||
return std::vector<json>({json::object()});
|
||||
}
|
||||
|
||||
choices = json::array({json{
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}
|
||||
}});
|
||||
}
|
||||
} else {
|
||||
// Content message
|
||||
if (legacy_format) {
|
||||
choices = json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta",
|
||||
json{
|
||||
{"content", content},
|
||||
}},
|
||||
{"logprobs", result.contains("completion_probabilities") ?
|
||||
result["completion_probabilities"] : nullptr},
|
||||
{"text", content}
|
||||
}});
|
||||
} else {
|
||||
if (first) {
|
||||
if (content.empty()) {
|
||||
choices = json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"role", "assistant"}}}
|
||||
}});
|
||||
} else {
|
||||
if (content.empty()) {
|
||||
return std::vector<json>({json::object()});
|
||||
}
|
||||
|
||||
// Split into two messages for first content in chat mode
|
||||
json initial_ret = json{
|
||||
{"choices", json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"role", "assistant"}}}
|
||||
}})},
|
||||
{"created", t},
|
||||
{"id", completion_id},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}
|
||||
};
|
||||
|
||||
json second_ret = json{
|
||||
{"choices", json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"content", content}}}
|
||||
}})},
|
||||
{"created", t},
|
||||
{"id", completion_id},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}
|
||||
};
|
||||
|
||||
return std::vector<json>({initial_ret, second_ret});
|
||||
}
|
||||
} else {
|
||||
choices = json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"content", content}}}
|
||||
}});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
// Construct the response
|
||||
json ret = json{
|
||||
{"choices", choices},
|
||||
{"created", t},
|
||||
{"id", completion_id},
|
||||
{"model", modelname},
|
||||
{"object", "chat.completion.chunk"}
|
||||
{"id", completion_id},
|
||||
{"model", modelname},
|
||||
{"object", legacy_format ? "text_completion" : "chat.completion.chunk"}
|
||||
};
|
||||
|
||||
// Add timings if present
|
||||
if (result.contains("timings")) {
|
||||
ret.push_back({"timings", json_value(result, "timings", json::object())});
|
||||
ret["timings"] = json_value(result, "timings", json::object());
|
||||
}
|
||||
|
||||
// Add usage statistics for final messages
|
||||
if (!finish_reason.empty()) {
|
||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
ret.push_back({"usage", json {
|
||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||
ret["usage"] = json{
|
||||
{"completion_tokens", num_tokens_predicted},
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||
}});
|
||||
{"prompt_tokens", num_prompt_tokens},
|
||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||
};
|
||||
}
|
||||
|
||||
return std::vector<json>({ret});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue