Fix timing reporting and accumulation

Georgi Gerganov 2023-03-22 07:17:42 +02:00
parent 4d2e035347
commit 71ed3d224d
2 changed files with 6 additions and 12 deletions
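
The change moves timing into the library itself: llama_eval_internal and llama_sample_top_p_top_k now accumulate both the total elapsed microseconds and a call count, so per-run averages can be derived later instead of each caller keeping its own timers. Below is a minimal sketch of how such counters can be turned into a report; the field names t_eval_us, n_eval, t_sample_us and n_sample follow the diff, but the struct and the print helper are illustrative assumptions, not the upstream API.

// Hedged sketch (not the upstream API): reporting accumulated timing counters.
#include <cstdint>
#include <cstdio>

struct llama_timings_sketch {
    int64_t t_eval_us   = 0; // total wall time spent in eval
    int32_t n_eval      = 0; // number of eval calls accumulated
    int64_t t_sample_us = 0; // total wall time spent sampling
    int32_t n_sample    = 0; // number of sampled tokens
};

static void print_timings_sketch(const llama_timings_sketch & t) {
    // guard against division by zero before the first eval/sample
    const int32_t n_eval   = t.n_eval   > 0 ? t.n_eval   : 1;
    const int32_t n_sample = t.n_sample > 0 ? t.n_sample : 1;

    fprintf(stderr, "eval time:   %8.2f ms / %5d runs (%8.2f ms per run)\n",
            t.t_eval_us   / 1000.0, t.n_eval,   t.t_eval_us   / 1000.0 / n_eval);
    fprintf(stderr, "sample time: %8.2f ms / %5d runs (%8.2f ms per run)\n",
            t.t_sample_us / 1000.0, t.n_sample, t.t_sample_us / 1000.0 / n_sample);
}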

llama.cpp

@@ -611,6 +611,8 @@ static bool llama_eval_internal(
         const int n_tokens,
         const int n_past,
         const int n_threads) {
+    const int64_t t_start_us = ggml_time_us();
+
     const int N = n_tokens;
 
     const auto & model = lctx.model;
@@ -834,6 +836,9 @@ static bool llama_eval_internal(
 
     ggml_free(ctx0);
 
+    lctx.t_eval_us += ggml_time_us() - t_start_us;
+    lctx.n_eval++;
+
     return true;
 }
@@ -1509,6 +1514,7 @@ llama_token llama_sample_top_p_top_k(
             repeat_penalty);
 
     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    ctx->n_sample++;
 
     return result;
 }

main.cpp

@@ -156,7 +156,6 @@ void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     ggml_time_init();
-    const int64_t t_main_start_us = ggml_time_us();
 
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
@@ -226,9 +225,6 @@ int main(int argc, char ** argv) {
 
     int n_past = 0;
 
-    int64_t t_sample_us  = 0;
-    int64_t t_predict_us = 0;
-
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
@@ -320,14 +316,10 @@ int main(int argc, char ** argv) {
     while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
-            const int64_t t_start_us = ggml_time_us();
-
             if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
-
-            t_predict_us += ggml_time_us() - t_start_us;
         }
 
         n_past += embd.size();
@@ -343,8 +335,6 @@ int main(int argc, char ** argv) {
             llama_token id = 0;
 
             {
-                const int64_t t_start_sample_us = ggml_time_us();
-
                 auto logits = llama_get_logits(ctx);
 
                 if (params.ignore_eos) {
@@ -359,8 +349,6 @@ int main(int argc, char ** argv) {
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
-
-                t_sample_us += ggml_time_us() - t_start_sample_us;
             }
 
             // add it to the context