From 0d6485f0f830d9fd3de5680e861f897d6e9312aa Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 4 Dec 2024 15:03:37 +0100
Subject: [PATCH] wip [no ci]

---
 examples/server/server.cpp | 26 +++++++++-----
 examples/server/server.hpp |  2 ++
 examples/server/utils.hpp  | 71 +++++++++++++++++---------------------
 3 files changed, 51 insertions(+), 48 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index de073b085..a673fb415 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -494,7 +494,9 @@ struct server_response {
     }
 
     // Send a new result to a waiting id_task
-    void send(server_task_result & result) {
+    template <typename T>
+    void send(T & result) {
+        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
         SRV_DBG("sending result for task id = %d\n", result.id);
 
         std::unique_lock<std::mutex> lock(mutex_results);
@@ -502,7 +504,7 @@ struct server_response {
             if (result.id == id_task) {
                 SRV_DBG("task id = %d pushed to result queue\n", result.id);
 
-                queue_results.push_back(std::make_unique<server_task_result>(result));
+                queue_results.push_back(std::make_unique<T>(std::move(result)));
                 condition_results.notify_all();
                 return;
             }
@@ -1166,8 +1168,10 @@ struct server_context {
 
     void send_partial_response(server_slot & slot, completion_token_output tkn) {
         server_task_result_cmpl_partial res;
-        res.id      = slot.id_task;
-        res.content = tkn.text_to_send;
+        res.id              = slot.id_task;
+        res.n_decoded       = slot.n_decoded;
+        res.n_prompt_tokens = slot.n_prompt_tokens;
+        res.content         = tkn.text_to_send;
 
         if (slot.params.sampling.n_probs > 0) {
             const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1189,7 +1193,11 @@ struct server_context {
         queue_results.send(res);
     }
 
-    void send_final_response(const server_slot & slot) {
+    void send_final_response(server_slot & slot) {
+        if (slot.params.stream) {
+            return send_partial_response(slot, {0, "", {}});
+        }
+
         server_task_result_cmpl_final res;
         res.id      = slot.id_task;
         res.id_slot = slot.id;
@@ -1380,6 +1388,7 @@ struct server_context {
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<T>&)> & result_handler,
             const std::function<void(json)> & error_handler) {
+        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
         std::vector<T> results(id_tasks.size());
         for (size_t i = 0; i < id_tasks.size(); i++) {
             task_result_ptr result_raw = queue_results.recv(id_tasks);
@@ -2815,7 +2824,7 @@ int main(int argc, char ** argv) {
         if (!stream) {
             ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
                 // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].to_json(), completion_id, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose);
                 res_ok(res, result_oai);
             }, [&](const json & error_data) {
                 res_error(res, error_data);
@@ -2823,9 +2832,10 @@ int main(int argc, char ** argv) {
             });
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.to_json(), completion_id);
+                    std::vector<json> result_array = format_partial_response_oaicompat(model_name, result, completion_id);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
                             continue; // skip the stop token
diff --git a/examples/server/server.hpp b/examples/server/server.hpp
index 081ad2069..6197ae565 100644
--- a/examples/server/server.hpp
+++ b/examples/server/server.hpp
@@ -281,6 +281,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
     int index = 0;
     std::string content;
+    int32_t n_decoded;
+    int32_t n_prompt_tokens;
     stop_type stop = STOP_TYPE_NONE;
     std::vector<completion_token_output> probs_output;
     result_timings timings;
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index b01a7757f..98a777192 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -583,15 +583,14 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
-    bool stopped_word        = result.count("stopped_word") != 0;
-    bool stopped_eos         = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
-    std::string content      = json_value(result, "content", std::string(""));
-
+static json format_final_response_oaicompat(
+        const json & request,
+        server_task_result_cmpl_final & result,
+        const std::string & completion_id,
+        bool streaming = false,
+        bool verbose = false) {
     std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
+    if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
         finish_reason = "stop";
     }
 
@@ -601,7 +600,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
                                  {"delta", json::object()}}})
         : json::array({json{{"finish_reason", finish_reason},
                             {"index", 0},
-                            {"message", json{{"content", content},
+                            {"message", json{{"content", result.content},
                                              {"role", "assistant"}}}}});
 
     std::time_t t = std::time(0);
@@ -613,48 +612,42 @@ static json format_final_response_oaicompat(const json & request, const json & r
             json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
         {"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens",     num_prompt_tokens},
-            {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
+            {"completion_tokens", result.n_decoded},
+            {"prompt_tokens",     result.n_prompt_tokens},
+            {"total_tokens",      result.n_decoded + result.n_prompt_tokens}
         }},
         {"id", completion_id}
     };
 
     // extra fields for debugging purposes
     if (verbose) {
-        res["__verbose"] = result;
+        res["__verbose"] = result.to_json();
     }
 
-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
+    // TODO: fix this
+    // if (result.contains("completion_probabilities")) {
+    //     res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+    // }
 
-    if (result.contains("timings")) {
-        res.push_back({"timings", json_value(result, "timings", json::object())});
+    if (result.timings.prompt_n >= 0) {
+        res.push_back({"timings", result.timings.to_json()});
     }
 
     return res;
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word   = json_value(result, "stopped_word",  false);
-    bool stopped_eos    = json_value(result, "stopped_eos",   false);
-    bool stopped_limit  = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
+static std::vector<json> format_partial_response_oaicompat(
+        std::string modelname,
+        server_task_result_cmpl_partial & result,
+        const std::string & completion_id) {
+    bool first = result.n_decoded == 0;
+    std::string content = result.content;
 
     std::string finish_reason;
-    if (stopped_word || stopped_eos) {
+    if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
         finish_reason = "stop";
-    }
-    if (stopped_limit) {
+    } else if (result.stop == STOP_TYPE_LIMIT) {
         finish_reason = "length";
     }
 
@@ -724,17 +717,15 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
             {"object", "chat.completion.chunk"}
         };
 
-        if (result.contains("timings")) {
-            ret.push_back({"timings", json_value(result, "timings", json::object())});
+        if (result.timings.prompt_n >= 0) {
+            ret.push_back({"timings", result.timings.to_json()});
         }
 
         if (!finish_reason.empty()) {
-            int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-            int num_prompt_tokens    = json_value(result, "tokens_evaluated", 0);
             ret.push_back({"usage", json {
-                {"completion_tokens", num_tokens_predicted},
-                {"prompt_tokens",     num_prompt_tokens},
-                {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
+                {"completion_tokens", result.n_decoded},
+                {"prompt_tokens",     result.n_prompt_tokens},
+                {"total_tokens",      result.n_decoded + result.n_prompt_tokens}
             }});
         }