From a5603ded45938d1b3f6e0029dcc978eca90eeae4 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 25 Feb 2024 15:44:03 +0100 Subject: [PATCH] move llama_client_slot to utils.hpp --- examples/server/server.cpp | 298 +------------- .../server/tests/features/parallel.feature | 2 +- .../server/tests/features/security.feature | 2 +- examples/server/tests/features/server.feature | 2 +- .../tests/features/wrong_usages.feature | 2 +- examples/server/utils.hpp | 363 ++++++++++++++++-- 6 files changed, 344 insertions(+), 325 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6136c3903..2a8673d06 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -144,238 +144,6 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector cache_tokens; - std::vector generated_token_probs; - - bool infill = false; - bool embedding = false; - bool has_next_token = true; - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; - - bool oaicompat = false; - std::string oaicompat_model; - - std::string stopping_word; - - // sampling - struct llama_sampling_params sparams; - llama_sampling_context *ctx_sampling = nullptr; - - int32_t ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width - - int32_t n_past_se = 0; // self-extend - - // multimodal - std::vector images; - - // stats - size_t sent_count = 0; - size_t sent_token_probs_index = 0; - - int64_t t_start_process_prompt; - int64_t t_start_genereration; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - // multitasks - int multitask_id = -1; - - void reset() { - num_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - n_past = 0; - sent_count = 0; - sent_token_probs_index = 0; - infill = false; - ga_i = 0; - n_past_se = 0; - - generated_token_probs.clear(); - - for (slot_image & img : images) - { - free(img.image_embedding); - if (img.img_data) { - clip_image_u8_free(img.img_data); - } - img.prefix_prompt = ""; - } - - images.clear(); - } - - bool has_budget(gpt_params &global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) - { - return true; // limitless - } - - n_remaining = -1; - - if (params.n_predict != -1) - { - n_remaining = params.n_predict - n_decoded; - } - else if (global_params.n_predict != -1) - { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool available() const { - return state == IDLE && command == NONE; - } - - bool is_processing() const { - return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; - } - - void add_token_string(const completion_token_output &token) { - if (command == RELEASE) - { - return; - } - cache_tokens.push_back(token.tok); - generated_token_probs.push_back(token); - } - - void release() { - if (state == PROCESSING) - { - t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; - command = RELEASE; - } - } - - json get_formated_timings() { - return json - { - {"prompt_n", num_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, - - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - }; - } - - void print_timings() const { - char buffer[512]; - double t_token = t_prompt_processing / num_prompt_tokens_processed; - double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed; - sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", - t_prompt_processing, num_prompt_tokens_processed, - t_token, n_tokens_second); - LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, - {"t_prompt_processing", t_prompt_processing}, - {"num_prompt_tokens_processed", num_prompt_tokens_processed}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - t_token = t_token_generation / n_decoded; - n_tokens_second = 1e3 / t_token_generation * n_decoded; - sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", - t_token_generation, n_decoded, - t_token, n_tokens_second); - LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, - {"t_token_generation", t_token_generation}, - {"n_decoded", n_decoded}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation); - LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, - {"t_prompt_processing", t_prompt_processing}, - {"t_token_generation", t_token_generation}, - {"t_total", t_prompt_processing + t_token_generation}, - }); - } -}; - -struct llama_metrics { - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t n_tokens_predicted_total = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - - void on_prompt_eval(const llama_client_slot &slot) { - n_prompt_tokens_processed_total += slot.num_prompt_tokens_processed; - - n_prompt_tokens_processed += slot.num_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - } - - void on_prediction(const llama_client_slot &slot) { - n_tokens_predicted_total += slot.n_decoded; - - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - struct llama_server_context { llama_model *model = nullptr; @@ -1795,21 +1563,8 @@ struct llama_server_context if (slot.ga_n != 1) { - int ga_i = 0; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - int32_t slot_npast = 0; - for (int k = 0; k < slot.n_past; ++k) - { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; - } - slot_npast++; - } - slot.n_past_se = slot_npast; - slot.ga_i = ga_i; + // context extension via Self-Extend + slot.grp_attn_update_params(); } LOG_INFO("slot progression", { @@ -1855,22 +1610,16 @@ struct llama_server_context // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; - int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; - - int32_t ga_i = slot.ga_i; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; + int32_t slot_npast = slot.n_past; for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) { if (slot.ga_n != 1) { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; - } + // context extension via Self-Extend + slot_npast = slot.grp_attn_calc_npast(); } + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); slot_npast++; } @@ -1902,6 +1651,7 @@ struct llama_server_context all_slots_are_idle = true; } + // loop of n_batch for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); @@ -1911,28 +1661,9 @@ struct llama_server_context if (slot.ga_n != 1) { // context extension via Self-Extend - while (slot.n_past_se >= slot.ga_i + slot.ga_w) - { - const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; - const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); - const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; - - LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); - LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - - llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n); - llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd); - - slot.n_past_se -= bd; - - slot.ga_i += slot.ga_w / slot.ga_n; - - LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); - } - slot.n_past_se += n_tokens; + // TODO @ngxson: What happen if we're retrying with smaller n_batch? + // By the second time we retry, "grp_attn_shift" has already been called + slot.grp_attn_shift(ctx, n_tokens); } } @@ -1962,7 +1693,7 @@ struct llama_server_context slot.release(); } has_next_response = false; - break; + break; // break loop of n_batch } LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2); @@ -1970,14 +1701,15 @@ struct llama_server_context // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; i -= n_batch; - continue; + continue; // continue loop of n_batch } + // loop of slots for (auto & slot : slots) { if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; + continue; // continue loop of slots } // prompt evaluated for embedding @@ -1986,7 +1718,7 @@ struct llama_server_context send_embedding(slot); slot.release(); slot.i_batch = -1; - continue; + continue; // continue loop of slots } completion_token_output result; diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index c85f9de1d..6fe1e05de 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -2,7 +2,7 @@ Feature: Parallel Background: Server startup - Given a server listening on localhost:8080 + Given a server listening on 0.0.0.0:8080 And a model file stories260K.gguf And a model alias tinyllama-2 And 42 as server seed diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature index db06d3977..dba0849d1 100644 --- a/examples/server/tests/features/security.feature +++ b/examples/server/tests/features/security.feature @@ -2,7 +2,7 @@ Feature: Security Background: Server startup with an api key defined - Given a server listening on localhost:8080 + Given a server listening on 0.0.0.0:8080 And a model file stories260K.gguf And a server api key llama.cpp Then the server is starting diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index b571582a7..10941972e 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -2,7 +2,7 @@ Feature: llama.cpp server Background: Server startup - Given a server listening on localhost:8080 + Given a server listening on 0.0.0.0:8080 And a model file stories260K.gguf And a model alias tinyllama-2 And 42 as server seed diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature index e228b2371..f4fc6a8a2 100644 --- a/examples/server/tests/features/wrong_usages.feature +++ b/examples/server/tests/features/wrong_usages.feature @@ -6,7 +6,7 @@ Feature: Wrong usage of llama.cpp server # to cap the number of tokens any completion request can generate # or pass n_predict/max_tokens in the request. Scenario: Infinite loop - Given a server listening on localhost:8080 + Given a server listening on 0.0.0.0:8080 And a model file stories260K.gguf # Uncomment below to fix the issue #And 64 server max tokens to predict diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index bfe00d3a1..a5a6ab9f8 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -37,9 +37,49 @@ extern bool server_log_json; #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) -// -// parallel -// +static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) +{ + std::stringstream ss_tid; + ss_tid << std::this_thread::get_id(); + json log = nlohmann::ordered_json{ + {"tid", ss_tid.str()}, + {"timestamp", time(nullptr)}, + }; + + if (server_log_json) { + log.merge_patch( + { + {"level", level}, + {"function", function}, + {"line", line}, + {"msg", message}, + }); + if (!extra.empty()) { + log.merge_patch(extra); + } + + std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush; + } else { + char buf[1024]; + snprintf(buf, 1024, "%4s [%24s] %s", level, function, message); + + if (!extra.empty()) { + log.merge_patch(extra); + } + std::stringstream ss; + ss << buf << " |"; + for (const auto& el : log.items()) + { + const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); + snprintf(buf, 1024, " %s=%s", el.key().c_str(), value.c_str()); + ss << buf; + } + + const std::string str = ss.str(); + printf("%.*s\n", (int)str.size(), str.data()); + fflush(stdout); + } +} enum server_state { SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet @@ -134,49 +174,296 @@ struct completion_token_output std::string text_to_send; }; -static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) +struct llama_client_slot { - std::stringstream ss_tid; - ss_tid << std::this_thread::get_id(); - json log = nlohmann::ordered_json{ - {"tid", ss_tid.str()}, - {"timestamp", time(nullptr)}, - }; + int id; + int task_id = -1; - if (server_log_json) { - log.merge_patch( - { - {"level", level}, - {"function", function}, - {"line", line}, - {"msg", message}, - }); - if (!extra.empty()) { - log.merge_patch(extra); - } + struct slot_params params; - std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush; - } else { - char buf[1024]; - snprintf(buf, 1024, "%4s [%24s] %s", level, function, message); + slot_state state = IDLE; + slot_command command = NONE; - if (!extra.empty()) { - log.merge_patch(extra); - } - std::stringstream ss; - ss << buf << " |"; - for (const auto& el : log.items()) + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + int32_t n_predict = -1; + + int32_t num_prompt_tokens = 0; + int32_t num_prompt_tokens_processed = 0; + + json prompt; + std::string generated_text; + llama_token sampled; + std::vector cache_tokens; + std::vector generated_token_probs; + + bool infill = false; + bool embedding = false; + bool has_next_token = true; + bool truncated = false; + bool stopped_eos = false; + bool stopped_word = false; + bool stopped_limit = false; + + bool oaicompat = false; + std::string oaicompat_model; + + std::string stopping_word; + + // sampling + struct llama_sampling_params sparams; + llama_sampling_context *ctx_sampling = nullptr; + + int32_t ga_i = 0; // group-attention state + int32_t ga_n = 1; // group-attention factor + int32_t ga_w = 512; // group-attention width + + int32_t n_past_se = 0; // self-extend + + // multimodal + std::vector images; + + // stats + size_t sent_count = 0; + size_t sent_token_probs_index = 0; + + int64_t t_start_process_prompt; + int64_t t_start_genereration; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + // multitasks + int multitask_id = -1; + + void reset() { + num_prompt_tokens = 0; + generated_text = ""; + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + n_past = 0; + sent_count = 0; + sent_token_probs_index = 0; + infill = false; + ga_i = 0; + n_past_se = 0; + + generated_token_probs.clear(); + + for (slot_image & img : images) { - const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); - snprintf(buf, 1024, " %s=%s", el.key().c_str(), value.c_str()); - ss << buf; + free(img.image_embedding); + if (img.img_data) { + clip_image_u8_free(img.img_data); + } + img.prefix_prompt = ""; } - const std::string str = ss.str(); - printf("%.*s\n", (int)str.size(), str.data()); - fflush(stdout); + images.clear(); } -} + + bool has_budget(gpt_params &global_params) { + if (params.n_predict == -1 && global_params.n_predict == -1) + { + return true; // limitless + } + + n_remaining = -1; + + if (params.n_predict != -1) + { + n_remaining = params.n_predict - n_decoded; + } + else if (global_params.n_predict != -1) + { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool available() const { + return state == IDLE && command == NONE; + } + + bool is_processing() const { + return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; + } + + void add_token_string(const completion_token_output &token) { + if (command == RELEASE) + { + return; + } + cache_tokens.push_back(token.tok); + generated_token_probs.push_back(token); + } + + void release() { + if (state == PROCESSING) + { + t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; + command = RELEASE; + } + } + + json get_formated_timings() { + return json + { + {"prompt_n", num_prompt_tokens_processed}, + {"prompt_ms", t_prompt_processing}, + {"prompt_per_token_ms", t_prompt_processing / num_prompt_tokens_processed}, + {"prompt_per_second", 1e3 / t_prompt_processing * num_prompt_tokens_processed}, + + {"predicted_n", n_decoded}, + {"predicted_ms", t_token_generation}, + {"predicted_per_token_ms", t_token_generation / n_decoded}, + {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, + }; + } + + void print_timings() const { + char buffer[512]; + double t_token = t_prompt_processing / num_prompt_tokens_processed; + double n_tokens_second = 1e3 / t_prompt_processing * num_prompt_tokens_processed; + sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", + t_prompt_processing, num_prompt_tokens_processed, + t_token, n_tokens_second); + LOG_INFO(buffer, { + {"slot_id", id}, + {"task_id", task_id}, + {"t_prompt_processing", t_prompt_processing}, + {"num_prompt_tokens_processed", num_prompt_tokens_processed}, + {"t_token", t_token}, + {"n_tokens_second", n_tokens_second}, + }); + + t_token = t_token_generation / n_decoded; + n_tokens_second = 1e3 / t_token_generation * n_decoded; + sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", + t_token_generation, n_decoded, + t_token, n_tokens_second); + LOG_INFO(buffer, { + {"slot_id", id}, + {"task_id", task_id}, + {"t_token_generation", t_token_generation}, + {"n_decoded", n_decoded}, + {"t_token", t_token}, + {"n_tokens_second", n_tokens_second}, + }); + + sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation); + LOG_INFO(buffer, { + {"slot_id", id}, + {"task_id", task_id}, + {"t_prompt_processing", t_prompt_processing}, + {"t_token_generation", t_token_generation}, + {"t_total", t_prompt_processing + t_token_generation}, + }); + } + + // context extension via Self-Extend + void grp_attn_update_params() { + int grpa_i = 0; + // copy to local variables + int32_t grpa_n = ga_n; + int32_t grpa_w = ga_w; + int32_t slot_npast = 0; + for (int k = 0; k < n_past; ++k) + { + while (slot_npast >= grpa_i + grpa_w) { + const int bd = (grpa_w/grpa_n)*(grpa_n - 1); + slot_npast -= bd; + grpa_i += grpa_w/grpa_n; + } + slot_npast++; + } + n_past_se = slot_npast; + ga_i = grpa_i; + } + + int32_t grp_attn_calc_npast() { + int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past; + // copy to local variables + int32_t grpa_i = ga_i; + int32_t grpa_n = ga_n; + int32_t grpa_w = ga_w; + while (slot_npast >= grpa_i + grpa_w) { + const int bd = (grpa_w/grpa_n)*(grpa_n - 1); + slot_npast -= bd; + grpa_i += grpa_w/grpa_n; + } + return slot_npast; + } + + void grp_attn_shift(llama_context * ctx, const int32_t n_tokens) { + while (n_past_se >= ga_i + ga_w) + { + const int ib = (ga_n * ga_i) / ga_w; + const int bd = (ga_w / ga_n) * (ga_n - 1); + const int dd = (ga_w / ga_n) - ib * bd - ga_w; + + LOG_TEE("\n"); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past_se, ib * bd, ga_i + ib * bd, n_past_se + ib * bd); + LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n, (ga_i + ib * bd) / ga_n, (ga_i + ib * bd + ga_w) / ga_n); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd, ga_i + ib * bd + ga_w + dd, n_past_se + ib * bd + dd); + + llama_kv_cache_seq_shift(ctx, id, ga_i, n_past_se, ib * bd); + llama_kv_cache_seq_div(ctx, id, ga_i + ib * bd, ga_i + ib * bd + ga_w,ga_n); + llama_kv_cache_seq_shift(ctx, id, ga_i + ib * bd + ga_w,n_past_se + ib * bd, dd); + + n_past_se -= bd; + + ga_i += ga_w / ga_n; + + LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past_se + bd, n_past_se, ga_i); + } + n_past_se += n_tokens; + } +}; + +struct llama_metrics { + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t n_tokens_predicted_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + + void on_prompt_eval(const llama_client_slot &slot) { + n_prompt_tokens_processed_total += slot.num_prompt_tokens_processed; + + n_prompt_tokens_processed += slot.num_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + } + + void on_prediction(const llama_client_slot &slot) { + n_tokens_predicted_total += slot.n_decoded; + + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; // // server utils