add timings

Author: lhpqaq
Date:   2024-11-29 10:20:58 +08:00
Commit: fb10521514 (parent 2c96bd2466)

2 changed files with 23 additions and 6 deletions

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp

@@ -1313,6 +1313,7 @@ struct server_context {
             {"id_slot",    slot.id},
             {"multimodal", false},
             {"index",      slot.index},
+            {"timings",    slot.get_formated_timings()},
         };
 
         if (slot.params.sampling.n_probs > 0) {
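
Note: slot.get_formated_timings() is defined outside this diff. The sketch below guesses at the shape of the object it attaches, based on the timing fields updated in the next hunk; the field names are assumptions, not something this commit shows.

    // Hypothetical reconstruction of the "timings" payload; the real builder
    // (slot.get_formated_timings) lives outside the hunks shown here.
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    static json sketch_formated_timings(int n_prompt, double t_prompt_ms,
                                        int n_decoded, double t_gen_ms) {
        return json {
            {"prompt_n",             n_prompt},     // prompt tokens evaluated
            {"prompt_ms",            t_prompt_ms},  // slot.t_prompt_processing
            {"prompt_per_second",    t_prompt_ms > 0 ? 1e3 * n_prompt / t_prompt_ms : 0.0},
            {"predicted_n",          n_decoded},    // slot.n_decoded
            {"predicted_ms",         t_gen_ms},     // slot.t_token_generation
            {"predicted_per_second", t_gen_ms > 0 ? 1e3 * n_decoded / t_gen_ms : 0.0},
        };
    }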
@@ -2274,12 +2275,17 @@ struct server_context {
 
             common_sampler_accept(slot.smpl, id, true);
 
             slot.n_decoded += 1;
+
+            const int64_t t_current = ggml_time_us();
+
             if (slot.n_decoded == 1) {
-                slot.t_start_generation = ggml_time_us();
+                slot.t_start_generation = t_current;
                 slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                 metrics.on_prompt_eval(slot);
             }
 
+            slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
+
             completion_token_output result;
             result.tok = id;
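
The hunk above reads ggml_time_us() once per decoded token: the first token closes out prompt processing (fixing t_prompt_processing), and every token refreshes t_token_generation so that in-flight responses can report generation time so far. A self-contained sketch of the same pattern, with std::chrono standing in for ggml_time_us() and a simplified struct standing in for server_slot:

    #include <chrono>
    #include <cstdint>
    #include <cstdio>

    // std::chrono stand-in for ggml_time_us(): monotonic microseconds.
    static int64_t time_us() {
        using namespace std::chrono;
        return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
    }

    // Simplified stand-in for the timing fields of server_slot.
    struct SlotTimings {
        int     n_decoded              = 0;
        int64_t t_start_process_prompt = 0; // set when prompt processing starts
        int64_t t_start_generation     = 0; // set on the first decoded token
        double  t_prompt_processing    = 0; // ms, fixed once at the first token
        double  t_token_generation     = 0; // ms, refreshed on every token
    };

    static void on_token_decoded(SlotTimings & slot) {
        slot.n_decoded += 1;
        const int64_t t_current = time_us();
        if (slot.n_decoded == 1) {
            // Everything before the first generated token was prompt processing.
            slot.t_start_generation  = t_current;
            slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
        }
        // Running total, so mid-stream chunks can carry up-to-date timings.
        slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
    }

    int main() {
        SlotTimings slot;
        slot.t_start_process_prompt = time_us();
        for (int i = 0; i < 3; ++i) on_token_decoded(slot);
        std::printf("prompt: %.3f ms, generation so far: %.3f ms\n",
                    slot.t_prompt_processing, slot.t_token_generation);
    }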
@@ -2995,13 +3001,15 @@ int main(int argc, char ** argv) {
         ctx_server.queue_tasks.post(tasks);
 
         bool stream = json_value(data, "stream", false);
+        bool timings = json_value(data, "timing_per_token", false);
+
         const auto task_ids = server_task::get_list_id(tasks);
         const auto completion_id = gen_chatcmplid();
 
         if (!stream) {
             ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
                 // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
                 res_ok(res, result_oai);
             }, [&](const json & error_data) {
                 res_error(res, error_data);
@@ -3009,9 +3017,9 @@ int main(int argc, char ** argv) {
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
+                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
                             continue; // skip the stop token
                         }
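
Taken together, the server.cpp changes make the feature opt-in per request through a new "timing_per_token" field in the request body: the non-streaming path forwards the flag to the final formatter, and the streaming lambda captures it by value. A hedged client-side sketch, assuming a server listening on localhost:8080 and the cpp-httplib and nlohmann::json libraries the project already uses; the endpoint and flag name come from this diff, the rest is illustrative:

    #include "httplib.h"          // cpp-httplib
    #include <nlohmann/json.hpp>
    #include <cstdio>

    using json = nlohmann::json;

    int main() {
        httplib::Client cli("localhost", 8080); // assumed server address

        const json msg  = {{"role", "user"}, {"content", "Hi"}};
        const json body = {
            {"messages",         json::array({msg})},
            {"stream",           false},
            {"timing_per_token", true},   // the flag added by this commit
        };

        auto res = cli.Post("/v1/chat/completions", body.dump(), "application/json");
        if (res && res->status == 200) {
            const json reply = json::parse(res->body);
            if (reply.contains("timings")) {  // present only when the flag is set
                std::printf("%s\n", reply["timings"].dump(2).c_str());
            }
        }
    }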

diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp

@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -650,11 +650,15 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }
 
+    if (timings) {
+        res.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     return res;
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
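
A detail worth knowing to read the hunk above: res is an nlohmann::json object, and calling push_back with a two-element {key, value} initializer list on an object inserts a member rather than appending to an array; json_value supplies an empty object when the underlying result carries no timings. A minimal demo of that push_back behavior (values illustrative):

    #include <nlohmann/json.hpp>
    #include <cstdio>

    using json = nlohmann::json;

    int main() {
        json res = {{"object", "chat.completion"}};
        // On an object, push_back({key, value}) inserts a member.
        res.push_back({"timings", json{{"predicted_ms", 12.3}}});
        std::printf("%s\n", res.dump(2).c_str());
        // -> { "object": "chat.completion", "timings": { "predicted_ms": 12.3 } }
    }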
@@ -740,6 +744,11 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
             {"model",  modelname},
             {"object", "chat.completion.chunk"}
         };
+
+        if (timings) {
+            ret.push_back({"timings", json_value(result, "timings", json::object())});
+        }
+
         if (!finish_reason.empty()) {
             int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
             int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
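
Because this last hunk runs for every chunk, not just the final one, each streamed SSE event can now carry a "timings" member when the request set "timing_per_token". A sketch of consuming one such data payload; the top-level shape follows the hunk above, while the field names inside "timings" are assumptions:

    #include <nlohmann/json.hpp>
    #include <cstdio>

    using json = nlohmann::json;

    int main() {
        // Illustrative body of one SSE "data:" line from the stream.
        const char * chunk = R"({
            "object":  "chat.completion.chunk",
            "choices": [{"index": 0, "delta": {"content": "Hi"}}],
            "timings": {"predicted_n": 5, "predicted_ms": 41.7}
        })";

        const json j = json::parse(chunk);
        if (j.contains("timings")) {
            std::printf("tokens so far: %d\n", j["timings"].value("predicted_n", 0));
        }
    }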