From c5efd837b61f3c46cacd0fe8567e6a098a7b0766 Mon Sep 17 00:00:00 2001
From: Minsoo Cheong
Date: Tue, 5 Mar 2024 17:37:13 +0900
Subject: [PATCH] server: maintain chat completion id for streaming responses

---
 examples/server/oai.hpp    | 12 ++++++------
 examples/server/server.cpp |  7 ++++---
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index ff4ad6994..50ca48f93 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -69,7 +69,7 @@ inline static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+inline static json format_final_response_oaicompat(const json &request, const task_result &response, std::string id, bool streaming = false)
 {
     json result = response.result_json;
 
@@ -105,7 +105,7 @@ inline static json format_final_response_oaicompat(const json &request, const ta
             json{{"completion_tokens", num_tokens_predicted},
                  {"prompt_tokens",     num_prompt_tokens},
                  {"total_tokens",      num_tokens_predicted + num_prompt_tokens}}},
-        {"id", gen_chatcmplid()}};
+        {"id", id}};
 
     if (server_verbose) {
         res["__verbose"] = result;
@@ -119,7 +119,7 @@ inline static json format_final_response_oaicompat(const json &request, const ta
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-inline static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
+inline static std::vector<json> format_partial_response_oaicompat(const task_result &response, std::string id) {
     json result = response.result_json;
 
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
@@ -165,7 +165,7 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
                             {"role", "assistant"}
                         }}}})},
                     {"created", t},
-                    {"id", gen_chatcmplid()},
+                    {"id", id},
                     {"model", modelname},
                     {"object", "chat.completion.chunk"}};
 
@@ -176,7 +176,7 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
                             {"content", content}}}
                     }})},
                 {"created", t},
-                {"id", gen_chatcmplid()},
+                {"id", id},
                 {"model", modelname},
                 {"object", "chat.completion.chunk"}};
 
@@ -202,7 +202,7 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
 
     json ret = json{{"choices", choices},
                     {"created", t},
-                    {"id", gen_chatcmplid()},
+                    {"id", id},
                     {"model", modelname},
                     {"object", "chat.completion.chunk"}};
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8fe5e0b19..33d6faab6 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3210,7 +3210,8 @@ int main(int argc, char **argv)
             res.set_content(models.dump(), "application/json; charset=utf-8");
         });
     const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
     {
+        const std::string completion_id = gen_chatcmplid();
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!validate_api_key(req, res)) {
             return;
@@ -3227,7 +3228,7 @@ int main(int argc, char **argv)
             task_result result = llama.queue_results.recv(task_id);
 
             if (!result.error && result.stop) {
-                json oaicompat_result = format_final_response_oaicompat(data, result);
+                json oaicompat_result = format_final_response_oaicompat(data, result, completion_id);
 
                 res.set_content(oaicompat_result.dump(-1, ' ', false, json::error_handler_t::replace),
                                 "application/json; charset=utf-8");
@@ -3238,11 +3239,11 @@ int main(int argc, char **argv)
             llama.queue_results.remove_waiting_task_id(task_id);
         } else {
-            const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
+            const auto chunked_content_provider = [task_id, &llama, completion_id](size_t, httplib::DataSink &sink) {
                 while (true) {
                     task_result llama_result = llama.queue_results.recv(task_id);
                     if (!llama_result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(llama_result);
+                        std::vector<json> result_array = format_partial_response_oaicompat(llama_result, completion_id);
                         for (auto it = result_array.begin(); it != result_array.end(); ++it)
                         {
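
Verification sketch (illustrative only; it assumes the server is running on its default port 8080, and the field values shown are made up, not captured output): request a streamed chat completion and compare the "id" field across the received chunks. With this patch, every chunk of one streamed response carries the same "chatcmpl-..." id, and a second request gets a fresh id from gen_chatcmplid().

    curl http://localhost:8080/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{"messages": [{"role": "user", "content": "hello"}], "stream": true}'

    data: {"choices":[...],"created":1709627833,"id":"chatcmpl-Xyz123","model":"...","object":"chat.completion.chunk"}

    data: {"choices":[...],"created":1709627833,"id":"chatcmpl-Xyz123","model":"...","object":"chat.completion.chunk"}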