From 291a7855873c89be709868fbae1d076bc37bfde3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 22 Oct 2024 15:03:00 +0200 Subject: [PATCH 1/2] llama : rename batch.logits to batch.output This commit renames the `logits` field of the `llama_batch` struct to `output`. The motivation for this change (apart from the TODO comment) is that the `logits` field is actually used to specify that output should be generated. For example, in the case of generating embeddings, setting logits to true can be confusing since the logits are not used when generating embeddings. --- common/common.cpp | 6 +++--- examples/batched-bench/batched-bench.cpp | 4 ++-- examples/batched.swift/Sources/main.swift | 6 +++--- examples/batched/batched.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- .../llama/src/main/cpp/llama-android.cpp | 6 +++--- .../llama.cpp.swift/LibLlama.swift | 8 ++++---- examples/llava/llava.cpp | 8 ++++---- examples/parallel/parallel.cpp | 4 ++-- examples/passkey/passkey.cpp | 4 ++-- examples/perplexity/perplexity.cpp | 14 +++++++------- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 8 ++++---- examples/tts/tts.cpp | 2 +- include/llama.h | 2 +- src/llama-batch.cpp | 18 +++++++++--------- src/llama-batch.h | 2 +- src/llama.cpp | 5 ++--- 19 files changed, 52 insertions(+), 53 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8661e164a..859e726af 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat << ", pos " << std::to_string(batch.pos[i]) << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) << ", seq_id " << std::to_string(batch.seq_id[i][0]) - << ", logits " << std::to_string(batch.logits[i]); + << ", output " << std::to_string(batch.output[i]); } buf << " ]"; @@ -1617,7 +1617,7 @@ void common_batch_add( llama_token id, llama_pos pos, const std::vector & seq_ids, - bool logits) { + bool output) { GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); batch.token [batch.n_tokens] = id; @@ -1626,7 +1626,7 @@ void common_batch_add( for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } - batch.logits [batch.n_tokens] = logits; + batch.output [batch.n_tokens] = output; batch.n_tokens++; } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 0659ab6f1..1f1c95627 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -73,7 +73,7 @@ int main(int argc, char ** argv) { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); @@ -128,7 +128,7 @@ int main(int argc, char ** argv) { common_batch_add(batch, 0, i, { j }, false); } } - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; const auto t_pp_start = ggml_time_us(); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 55c31166c..18b6a21d8 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() { if let seq_id = batch.seq_id[i] { seq_id[0] = 0 } - batch.logits[i] = 0 + batch.output[i] = 0 } // llama_decode will output logits only for the last token of the prompt -batch.logits[Int(batch.n_tokens) - 1] = 1 +batch.output[Int(batch.n_tokens) - 1] = 1 if llama_decode(context, batch) != 0 { print("llama_decode() failed") @@ -171,7 +171,7 @@ while n_cur <= n_len { if let seq_id = batch.seq_id[Int(batch.n_tokens)] { seq_id[0] = Int32(i) } - batch.logits[Int(batch.n_tokens)] = 1 + batch.output[Int(batch.n_tokens)] = 1 i_batch[i] = batch.n_tokens diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 21b95ef5e..7d2a82b51 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { } // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; if (llama_decode(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 38d22c90f..95445b5ef 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + if (!batch.output[i]) { continue; } diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2a73983a9..1718d6b4f 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( common_batch_add(*batch, 0, i, { 0 }, false); } - batch->logits[batch->n_tokens - 1] = true; + batch->output[batch->n_tokens - 1] = true; llama_kv_cache_clear(context); const auto t_pp_start = ggml_time_us(); @@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, for (int i = 0; i < n_tokens; ++i) { batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens); return reinterpret_cast(batch); } @@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( } // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; + batch->output[batch->n_tokens - 1] = true; if (llama_decode(context, *batch) != 0) { LOGe("llama_decode() failed"); diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index ee7141a66..dfece7761 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) { batch.n_tokens = 0 } -func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) { +func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) { batch.token [Int(batch.n_tokens)] = id batch.pos [Int(batch.n_tokens)] = pos batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count) for i in 0.. n_seq_id; std::vector seq_id_0; std::vector seq_ids; - std::vector logits; + std::vector outputs; llama_batch batch; llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { pos .resize(n_tokens); n_seq_id.resize(n_tokens); seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); + outputs .resize(n_tokens); seq_id_0.resize(1); seq_id_0[0] = seq_id; seq_ids [n_tokens] = nullptr; @@ -458,13 +458,13 @@ struct llava_embd_batch { /*pos =*/ pos.data(), /*n_seq_id =*/ n_seq_id.data(), /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), + /*output =*/ outputs.data(), }; for (int i = 0; i < n_tokens; i++) { batch.pos [i] = pos_0 + i; batch.n_seq_id[i] = 1; batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; + batch.output [i] = false; } } }; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7ef43d5e1..3f87c0a1a 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -266,7 +266,7 @@ int main(int argc, char ** argv) { // extract the logits only for the last token if (batch.n_tokens > 0) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } client.n_prompt = tokens_prompt.size(); @@ -309,7 +309,7 @@ int main(int argc, char ** argv) { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 5953928d4..15f99bcdd 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -146,7 +146,7 @@ int main(int argc, char ** argv) { } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } if (llama_decode(ctx, batch) != 0) { @@ -180,7 +180,7 @@ int main(int argc, char ** argv) { } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; } if (llama_decode(ctx, batch) != 0) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9bf6c5743..2b194b8d9 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & batch.pos [idx] = j*n_batch + k; batch.n_seq_id[idx] = 1; batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + batch.output [idx] = batch.pos[idx] >= first ? 1 : 0; - n_outputs += batch.logits[idx] != 0; + n_outputs += batch.output[idx] != 0; } batch.n_tokens += batch_size; @@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); @@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< int n_outputs = 0; for (int i = 0; i < n_tokens; ++i) { - n_outputs += batch_view.logits[i] != 0; + n_outputs += batch_view.output[i] != 0; } memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); @@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) { common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < 4; ++s) { @@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) for (size_t i = 0; i < data[i1].common_prefix; ++i) { common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; n_logits += 1; for (int s = 0; s < 2; ++s) { @@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 2439022a2..2c5b5e486 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + if (!batch.output[i]) { continue; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index cf7cbd815..2e5a2b518 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -52,7 +52,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokens.size(); i++) { common_batch_add(batch, tokens[i], i, {0}, false); } - batch.logits[batch.n_tokens - 1] = true; // generate next token + batch.output[batch.n_tokens - 1] = true; // generate next token // evaluate prompt llama_decode(ctx, batch); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9cdf2058f..f6642e5c8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2413,7 +2413,7 @@ struct server_context { std::vector embd_res(n_embd, 0.0f); for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + if (!batch.output[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -2451,7 +2451,7 @@ struct server_context { res->n_tokens = slot.n_prompt_tokens; for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + if (!batch.output[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -3109,7 +3109,7 @@ struct server_context { } // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; @@ -3149,7 +3149,7 @@ struct server_context { batch.pos + i, batch.n_seq_id + i, batch.seq_id + i, - batch.logits + i, + batch.output + i, }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index f78f76303..f70022985 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + batch.output[batch.n_tokens - 1] = true; if (llama_decode(ctx_ttc, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); diff --git a/include/llama.h b/include/llama.h index 61907ed40..516953a72 100644 --- a/include/llama.h +++ b/include/llama.h @@ -252,7 +252,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int8_t * output; } llama_batch; enum llama_model_kv_override_type { diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57f..ba2127be6 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s ubatch.output[ubatch.n_tokens + i] = 1; out_ids.push_back(ids[seq.offset + i]); } - } else if (batch->logits) { + } else if (batch->output) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; - int8_t is_output = batch->logits[id]; + int8_t is_output = batch->output[id]; ubatch.output[ubatch.n_tokens + i] = is_output; if (is_output) { out_ids.push_back(id); } } } else { // simple split - ubatch.output = batch->logits + seq.offset; + ubatch.output = batch->output + seq.offset; for (size_t i = 0; i < length; ++i) { if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } } @@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 } batch.seq_id = seq_id.data(); } - if (!batch.logits) { - logits.resize(batch.n_tokens); - logits[logits.size() - 1] = true; - batch.logits = logits.data(); + if (!batch.output) { + outputs.resize(batch.n_tokens); + outputs[outputs.size() - 1] = true; + batch.output = outputs.data(); } } @@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ } batch.seq_id[n_tokens_alloc] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); return batch; } @@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) { } free(batch.seq_id); } - if (batch.logits) free(batch.logits); + if (batch.output) free(batch.output); } diff --git a/src/llama-batch.h b/src/llama-batch.h index 773c3808b..002a8a62f 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -81,7 +81,7 @@ struct llama_batch_allocr { std::vector pos; std::vector n_seq_id; std::vector seq_id; - std::vector logits; + std::vector outputs; // optionally fulfill the batch returned by llama_batch_get_one llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); diff --git a/src/llama.cpp b/src/llama.cpp index aae3c69b5..e24c39465 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8473,9 +8473,9 @@ static int llama_prepare_sbatch( lctx.embd_seq.clear(); // count outputs - if (batch.logits && !embd_pooled) { + if (batch.output && !embd_pooled) { for (uint32_t i = 0; i < n_tokens_all; ++i) { - n_outputs += batch.logits[i] != 0; + n_outputs += batch.output[i] != 0; } } else if (lctx.logits_all || embd_pooled) { n_outputs = n_tokens_all; @@ -9972,7 +9972,6 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { return llama_kv_cache_can_shift(ctx->kv_self); } -/// int32_t llama_encode( struct llama_context * ctx, From 27f59dbaaa33bb0941a533ac4fd257e0cc9c564b Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 6 Feb 2025 08:00:30 +0100 Subject: [PATCH 2/2] squash! llama : rename batch.logits to batch.output Fix incorrectly named field in LibLlama.swift, outputs -> output. --- examples/llama.swiftui/llama.cpp.swift/LibLlama.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index dfece7761..7b4a55f2f 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama for i in 0..