From e264f2239ecb9ae726c2809e3f180f0b7509c2a1 Mon Sep 17 00:00:00 2001
From: slaren
Date: Sun, 14 Jan 2024 19:49:21 +0100
Subject: [PATCH] perplexity : ignore n_batch, submit whole chunk in one call

---
 examples/perplexity/perplexity.cpp | 87 ++++++++++++++----------------
 llama.cpp                          |  5 +-
 2 files changed, 44 insertions(+), 48 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 9a77beca6..d04fed9cd 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -189,19 +189,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_batch = params.n_batch;
 
     int count = 0;
     double nll = 0.0;
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks\n", __func__, n_chunk);
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.ppl_stride;
-        const int end = start + calc_chunk;
-
-        const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
-        //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
+        //const int end = start + calc_chunk;
 
         std::vector<float> logits;
@@ -210,32 +206,26 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
 
         // clear the KV cache
         llama_kv_cache_clear(ctx);
 
-        for (int j = 0; j < num_batches; ++j) {
-            const int batch_start = start + j * n_batch;
-            const int batch_size = std::min(end - batch_start, n_batch);
-            //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
-                //fprintf(stderr, "%s : failed to eval\n", __func__);
-                return {tokens, -1, logit_history, prob_history};
-            }
-
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
-
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
-            }
-
-            const auto batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
-
-            if (j == 0) {
-                tokens[batch_start] = token_org;
-            }
+        // save original token and restore it after eval
+        const auto token_org = tokens[start];
+
+        // add BOS token at the start of each chunk - this must happen before the eval
+        if (add_bos) {
+            tokens[start] = llama_token_bos(llama_get_model(ctx));
         }
+
+        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + start, calc_chunk, 0, 0))) {
+            //fprintf(stderr, "%s : failed to eval\n", __func__);
+            return {tokens, -1, logit_history, prob_history};
+        }
+
+        // restore the original token in case it was set to BOS
+        tokens[start] = token_org;
+
+        const auto * batch_logits = llama_get_logits(ctx);
+        logits.insert(logits.end(), batch_logits, batch_logits + calc_chunk * n_vocab);
 
         const auto t_end = std::chrono::high_resolution_clock::now();
 
         if (i == 0) {
@@ -246,7 +236,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
-            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+            fprintf(stderr, "%.2f minutes ", total_seconds / 60.0);
+            fprintf(stderr, "(%.2f t/s)\n", n_ctx/t_total);
         }
 
         //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
@@ -327,9 +318,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * n_ctx;
-        const int end = start + n_ctx;
+        //const int end = start + n_ctx;
 
-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+        //const int num_batches = (n_ctx + n_batch - 1) / n_batch;
 
         std::vector<float> logits;
@@ -338,33 +329,33 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         // clear the KV cache
         llama_kv_cache_clear(ctx);
 
-        for (int j = 0; j < num_batches; ++j) {
-            const int batch_start = start + j * n_batch;
-            const int batch_size = std::min(end - batch_start, n_batch);
+        //for (int j = 0; j < num_batches; ++j) {
+        //    const int batch_start = start + j * n_batch;
+        //    const int batch_size = std::min(end - batch_start, n_batch);
 
             // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+            const auto token_org = tokens[start];
 
             // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+            if (add_bos) {
+                tokens[start] = llama_token_bos(llama_get_model(ctx));
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + start, n_ctx, 0, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
             // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
+            tokens[start] = token_org;
 
             const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
-        }
+            logits.insert(logits.end(), batch_logits, batch_logits + n_ctx * n_vocab);
+        //}
 
         const auto t_end = std::chrono::high_resolution_clock::now();
 
-        if (i == 0) {
+        if (i == 1) { // TODO: skipping the first chunk gives a better estimate, but breaks formatting
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total * n_chunk);
@@ -372,7 +363,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
-            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+            fprintf(stderr, "%.2f minutes ", total_seconds / 60.0);
+            fprintf(stderr, "(%.2f t/s)\n", n_ctx/t_total);
+
         }
 
         // We get the logits for all the tokens in the context window (params.n_ctx)
@@ -433,7 +426,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
             return {};
         }
 
-        const auto logits = llama_get_logits(ctx);
+        const auto * logits = llama_get_logits(ctx);
         result.insert(result.end(), logits, logits + n_tokens * n_vocab);
 
         n_past += n_tokens;
@@ -678,13 +671,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
+    //params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+    //params.n_batch = std::min(params.n_batch, params.n_ctx);
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
diff --git a/llama.cpp b/llama.cpp
index 2061a6965..07aae6791 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6222,7 +6222,7 @@ static int llama_decode_internal(
     logits_valid.clear();
     logits_valid.resize(n_tokens_all);
 
-    logits_out.clear();
+    memset(logits_out, 0, lctx.logits_size*sizeof(float));
 #endif
 
@@ -6428,6 +6428,9 @@ static int llama_decode_internal(
         }
     }
 
+    //ggml_backend_sched_synchronize(lctx.sched);
+    //lctx.buf_cpu_ub_cur = 0;
+
     // measure the performance only for the single-token evals
    if (n_tokens_all == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
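
Reviewer note, not part of the patch: the perplexity change leans on two things that are easy to miss in the diff. First, main() keeps params.logits_all = true, so after a decode llama_get_logits() exposes one row of n_vocab logits per submitted token. Second, the caller no longer splits the chunk into n_batch pieces; llama_decode() on this branch appears to handle the full n_tokens_all itself. Below is a minimal sketch of that consumption pattern against the llama.h API as of this commit; chunk_log_prob_sum and its parameters are illustrative names, not code from the tree.

#include <algorithm>
#include <cmath>

#include "llama.h"

// Sum of log-probabilities of each observed next token in a chunk, submitted
// in a single call. Assumes ctx was created with logits_all enabled and that
// n_tokens fits in the context.
static double chunk_log_prob_sum(llama_context * ctx, llama_token * tokens, int n_tokens) {
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));

    llama_kv_cache_clear(ctx);

    // whole chunk in one llama_decode() call - no caller-side n_batch loop
    if (llama_decode(ctx, llama_batch_get_one(tokens, n_tokens, 0, 0)) != 0) {
        return 0.0; // eval failed
    }

    // row t holds the logits predicted after token t
    const float * logits = llama_get_logits(ctx);

    double sum = 0.0;
    for (int t = 0; t < n_tokens - 1; ++t) {
        const float * row = logits + (size_t) t * n_vocab;

        // numerically stable log-softmax of the observed next token
        float max = row[0];
        for (int v = 1; v < n_vocab; ++v) {
            max = std::max(max, row[v]);
        }
        double denom = 0.0;
        for (int v = 0; v < n_vocab; ++v) {
            denom += std::exp(row[v] - max);
        }
        sum += (row[tokens[t + 1]] - max) - std::log(denom);
    }
    return sum; // chunk perplexity = exp(-sum / (n_tokens - 1))
}

The same layout assumption (row t = logits after token t) is what lets the patched functions replace the per-batch logits.insert() calls with a single copy of calc_chunk * n_vocab (or n_ctx * n_vocab) floats.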