server: maintain chat completion id for streaming responses

Minsoo Cheong 2024-03-05 17:37:13 +09:00
parent 6a87ac3a52
commit c5efd837b6
2 changed files with 11 additions and 10 deletions
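A minimal standalone sketch (not part of the commit) of the behaviour the change is after: the completion id is generated once per chat completion and reused for every streamed "chat.completion.chunk", instead of each chunk getting a fresh id from gen_chatcmplid() as it did before. The gen_chatcmplid() stand-in below is hypothetical and only approximates the server's helper; the chunk JSON is trimmed to the fields relevant here.

// Sketch: one id per completion, shared across all streamed chunks.
#include <cstdio>
#include <random>
#include <string>

// Hypothetical stand-in for the server's gen_chatcmplid() helper:
// "chatcmpl-" followed by a random alphanumeric suffix.
static std::string gen_chatcmplid() {
    static const char alphabet[] =
        "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    std::mt19937 rng{std::random_device{}()};
    std::uniform_int_distribution<std::size_t> pick(0, sizeof(alphabet) - 2);
    std::string id = "chatcmpl-";
    for (int i = 0; i < 30; i++) {
        id += alphabet[pick(rng)];
    }
    return id;
}

int main() {
    // Generated once, before streaming starts...
    const std::string completion_id = gen_chatcmplid();

    const char *deltas[] = {"Hello", ",", " world"};
    for (const char *delta : deltas) {
        // ...so every chunk of this stream reports the same "id".
        std::printf("data: {\"id\":\"%s\",\"object\":\"chat.completion.chunk\","
                    "\"choices\":[{\"delta\":{\"content\":\"%s\"}}]}\n\n",
                    completion_id.c_str(), delta);
    }
    return 0;
}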

View file

@@ -69,7 +69,7 @@ inline static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+inline static json format_final_response_oaicompat(const json &request, const task_result &response, std::string id, bool streaming = false)
 {
     json result = response.result_json;
@@ -105,7 +105,7 @@ inline static json format_final_response_oaicompat(const json &request, const ta
              json{{"completion_tokens", num_tokens_predicted},
                   {"prompt_tokens", num_prompt_tokens},
                   {"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
-        {"id", gen_chatcmplid()}};
+        {"id", id}};
 
     if (server_verbose) {
         res["__verbose"] = result;
@@ -119,7 +119,7 @@ inline static json format_final_response_oaicompat(const json &request, const ta
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
+inline static std::vector<json> format_partial_response_oaicompat(const task_result &response, std::string id) {
     json result = response.result_json;
 
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
@@ -165,7 +165,7 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
                                 {"role", "assistant"}
                             }}}})},
             {"created", t},
-            {"id", gen_chatcmplid()},
+            {"id", id},
             {"model", modelname},
             {"object", "chat.completion.chunk"}};
@@ -176,7 +176,7 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
                                 {"content", content}}}
                             }})},
             {"created", t},
-            {"id", gen_chatcmplid()},
+            {"id", id},
             {"model", modelname},
             {"object", "chat.completion.chunk"}};
@@ -202,7 +202,7 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
     json ret = json{{"choices", choices},
                     {"created", t},
-                    {"id", gen_chatcmplid()},
+                    {"id", id},
                     {"model", modelname},
                     {"object", "chat.completion.chunk"}};

View file

@@ -3210,7 +3210,8 @@ int main(int argc, char **argv)
             res.set_content(models.dump(), "application/json; charset=utf-8");
         });
 
-    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+    const std::string completion_id = gen_chatcmplid();
+    const auto chat_completions = [&llama, &validate_api_key, &sparams, &completion_id](const httplib::Request &req, httplib::Response &res)
     {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!validate_api_key(req, res)) {
@@ -3227,7 +3228,7 @@ int main(int argc, char **argv)
             task_result result = llama.queue_results.recv(task_id);
 
             if (!result.error && result.stop) {
-                json oaicompat_result = format_final_response_oaicompat(data, result);
+                json oaicompat_result = format_final_response_oaicompat(data, result, completion_id);
 
                 res.set_content(oaicompat_result.dump(-1, ' ', false,
                                                       json::error_handler_t::replace),
@@ -3238,11 +3239,11 @@ int main(int argc, char **argv)
                 }
                 llama.queue_results.remove_waiting_task_id(task_id);
             } else {
-                const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
+                const auto chunked_content_provider = [task_id, &llama, &completion_id](size_t, httplib::DataSink &sink) {
                     while (true) {
                         task_result llama_result = llama.queue_results.recv(task_id);
                         if (!llama_result.error) {
-                            std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
+                            std::vector<json> result_array = format_partial_response_oaicompat(llama_result, completion_id);
                             for (auto it = result_array.begin(); it != result_array.end(); ++it)
                             {