From 21f8b73d606812efce141aeb0f004112c0e6025d Mon Sep 17 00:00:00 2001
From: lhpqaq <657407891@qq.com>
Date: Sun, 1 Dec 2024 00:35:41 +0800
Subject: [PATCH] fix code

---
 examples/server/server.cpp | 18 ++++++++++++------
 examples/server/utils.hpp  |  8 ++++----
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a18489426..aaad0de5c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word = false;
     bool stopped_limit = false;
 
+    bool timing_per_token = false;
+
     bool oaicompat = false;
 
     std::string oaicompat_model;
@@ -882,6 +884,8 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
+        slot.timing_per_token = json_value(data, "timing_per_token", false);
+
         slot.params.stream = json_value(data, "stream", false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", true);
         slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
@@ -1269,7 +1273,6 @@ struct server_context {
             {"n_keep", slot.params.n_keep},
             {"n_discard", slot.params.n_discard},
             {"ignore_eos", slot.params.sampling.ignore_eos},
-            {"timing_per_token", slot.params.sampling.timing_per_token},
             {"stream", slot.params.stream},
           //{"logit_bias", slot.params.sampling.logit_bias},
             {"n_probs", slot.params.sampling.n_probs},
@@ -1280,6 +1283,7 @@ struct server_context {
             {"speculative.n_max", slot.params.speculative.n_max},
             {"speculative.n_min", slot.params.speculative.n_min},
             {"speculative.p_min", slot.params.speculative.p_min},
+            {"timing_per_token", slot.timing_per_token},
         };
     }
 
@@ -1314,7 +1318,6 @@ struct server_context {
            {"id_slot", slot.id},
            {"multimodal", false},
            {"index", slot.index},
-           {"timings", slot.get_formated_timings()},
        };
 
        if (slot.params.sampling.n_probs > 0) {
@@ -1338,6 +1341,10 @@ struct server_context {
            res.data["model"] = slot.oaicompat_model;
        }
 
+       if (slot.timing_per_token) {
+           res.data["timings"] = slot.get_formated_timings();
+       }
+
        queue_results.send(res);
    }
 
@@ -3002,7 +3009,6 @@ int main(int argc, char ** argv) {
        ctx_server.queue_tasks.post(tasks);
 
        bool stream = json_value(data, "stream", false);
-       bool timings = json_value(data, "timing_per_token", false);
 
        const auto task_ids = server_task::get_list_id(tasks);
        const auto completion_id = gen_chatcmplid();
@@ -3010,7 +3016,7 @@ int main(int argc, char ** argv) {
        if (!stream) {
            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
                // multitask is never support in chat completion, there is only one result
-               json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
+               json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
                res_ok(res, result_oai);
            }, [&](const json & error_data) {
                res_error(res, error_data);
@@ -3018,9 +3024,9 @@ int main(int argc, char ** argv) {
 
            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        } else {
-           const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
+           const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                   std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
+                   std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                   for (auto & event_data : result_array) {
                       if (event_data.empty()) {
                           continue; // skip the stop token
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 072811367..e4451532c 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -650,7 +650,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }
 
-    if (timings) {
+    if (result.contains("timings")) {
         res.push_back({"timings", json_value(result, "timings", json::object())});
     }
 
@@ -658,7 +658,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -745,7 +745,7 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"object", "chat.completion.chunk"}
     };
 
-    if (timings) {
+    if (result.contains("timings")) {
         ret.push_back({"timings", json_value(result, "timings", json::object())});
     }