diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 5fe1a62a8..50ca48f93 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -69,7 +69,7 @@ inline static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+inline static json format_final_response_oaicompat(const json &request, const task_result &response, std::string id, bool streaming = false)
 {
     json result = response.result_json;
 
@@ -105,7 +105,7 @@ inline static json format_final_response_oaicompat(const json &request, const ta
             json{{"completion_tokens", num_tokens_predicted},
                  {"prompt_tokens", num_prompt_tokens},
                  {"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
-        {"id", gen_chatcmplid()}};
+        {"id", id}};
 
     if (server_verbose) {
         res["__verbose"] = result;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f4a01397d..755bb76c5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3221,13 +3221,14 @@ int main(int argc, char **argv)
             const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
             llama.request_completion(task_id, data, false, false, -1);
+            const std::string completion_id = gen_chatcmplid();
 
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
                 task_result result = llama.queue_results.recv(task_id);
 
                 if (!result.error && result.stop) {
-                    json oaicompat_result = format_final_response_oaicompat(data, result);
+                    json oaicompat_result = format_final_response_oaicompat(data, result, completion_id);
 
                     res.set_content(oaicompat_result.dump(-1, ' ', false,
                                     json::error_handler_t::replace),
@@ -3238,8 +3239,7 @@ int main(int argc, char **argv)
                 }
                 llama.queue_results.remove_waiting_task_id(task_id);
             } else {
-                const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
-                    const std::string completion_id = gen_chatcmplid();
+                const auto chunked_content_provider = [task_id, &llama, completion_id](size_t, httplib::DataSink &sink) {
                     while (true) {
                         task_result llama_result = llama.queue_results.recv(task_id);
                         if (!llama_result.error) {