From 71ed3d224d8a217cbf5ded5fb7ffdc6e254e844c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 22 Mar 2023 07:17:42 +0200 Subject: [PATCH] Fix timing reporting and accumulation --- llama.cpp | 6 ++++++ main.cpp | 12 ------------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/llama.cpp b/llama.cpp index 2ca572a6b..c9edb84f5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -611,6 +611,8 @@ static bool llama_eval_internal( const int n_tokens, const int n_past, const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + const int N = n_tokens; const auto & model = lctx.model; @@ -834,6 +836,9 @@ static bool llama_eval_internal( ggml_free(ctx0); + lctx.t_eval_us += ggml_time_us() - t_start_us; + lctx.n_eval++; + return true; } @@ -1509,6 +1514,7 @@ llama_token llama_sample_top_p_top_k( repeat_penalty); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; return result; } diff --git a/main.cpp b/main.cpp index b93c93b08..17c69b870 100644 --- a/main.cpp +++ b/main.cpp @@ -156,7 +156,6 @@ void sigint_handler(int signo) { int main(int argc, char ** argv) { ggml_time_init(); - const int64_t t_main_start_us = ggml_time_us(); gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; @@ -226,9 +225,6 @@ int main(int argc, char ** argv) { int n_past = 0; - int64_t t_sample_us = 0; - int64_t t_predict_us = 0; - // Add a space in front of the first character to match OG llama tokenizer behavior params.prompt.insert(0, 1, ' '); @@ -320,14 +316,10 @@ int main(int argc, char ** argv) { while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { - const int64_t t_start_us = ggml_time_us(); - if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } - - t_predict_us += ggml_time_us() - t_start_us; } n_past += embd.size(); @@ -343,8 +335,6 @@ int main(int argc, char ** argv) { llama_token id = 0; { - const int64_t t_start_sample_us = ggml_time_us(); - auto logits = llama_get_logits(ctx); if (params.ignore_eos) { @@ -359,8 +349,6 @@ int main(int argc, char ** argv) { last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); - - t_sample_us += ggml_time_us() - t_start_sample_us; } // add it to the context