From 27ecb076cf2fd1809f523b7f19699ea97b7ec53a Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Thu, 22 Aug 2024 11:28:16 +0800 Subject: [PATCH] simplify the code --- examples/perplexity/perplexity.cpp | 11 +---------- src/llama.cpp | 6 +++--- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 517f93e07..9fec60878 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -399,7 +399,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // clear the KV cache llama_kv_cache_clear(ctx); - for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -407,16 +406,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & llama_batch batch = llama_batch_init(batch_size, 0, 1); for (int k = 0; k < batch_size; ++k) { const int idx = batch_start + k; - batch.token [k] = tokens[idx]; - batch.output [k] = 1; + llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true); } - batch.n_tokens = batch_size; - batch.pos = nullptr; - batch.n_seq_id = nullptr; - batch.seq_id = nullptr; - batch.all_pos_0 = j*n_batch; - batch.all_pos_1 = 1; - batch.all_seq_id = 0; //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); if (llama_decode(ctx, batch)) { diff --git a/src/llama.cpp b/src/llama.cpp index 7e798a07e..11fbd9313 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2874,17 +2874,17 @@ struct llama_sbatch { ubatch.output[ubatch.n_tokens + i] = 1; out_ids.push_back(ids[seq.offset + i]); } - } else if (batch->logits) { + } else if (batch->output) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; + int8_t is_output = batch->output[id]; ubatch.output[ubatch.n_tokens + i] = is_output; if (is_output) { out_ids.push_back(id); } } } else { // simple split - ubatch.output = batch->logits + seq.offset; + ubatch.output = batch->output + seq.offset; for (size_t i = 0; i < length; ++i) { if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } }