From 1c486169ed829ba10c625a5b162424f68608fd0a Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 11 Oct 2024 12:11:00 +0200
Subject: [PATCH] adapt all examples

---
 common/common.cpp                             |  4 +-
 examples/batched-bench/batched-bench.cpp      |  1 -
 .../cvector-generator/cvector-generator.cpp   |  2 +-
 examples/eval-callback/eval-callback.cpp      |  2 +-
 examples/imatrix/imatrix.cpp                  | 13 ++++++-
 examples/infill/infill.cpp                    |  2 +-
 examples/llama-bench/llama-bench.cpp          |  4 +-
 .../llama/src/main/cpp/llama-android.cpp      |  3 --
 examples/llava/llava-cli.cpp                  |  2 +-
 examples/llava/llava.cpp                      | 38 ++++++++++++++++++-
 examples/llava/minicpmv-cli.cpp               |  2 +-
 examples/lookahead/lookahead.cpp              |  4 +-
 examples/lookup/lookup.cpp                    |  4 +-
 examples/main/main.cpp                        |  4 +-
 examples/parallel/parallel.cpp                |  1 -
 examples/perplexity/perplexity.cpp            | 27 ++++++++++---
 examples/save-load-state/save-load-state.cpp  |  8 ++--
 examples/server/server.cpp                    |  1 -
 examples/speculative/speculative.cpp          |  6 +--
 src/llama.cpp                                 |  1 +
 20 files changed, 92 insertions(+), 37 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index a0611f3d1..253836129 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -912,7 +912,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
         if (llama_model_has_encoder(model)) {
-            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == -1) {
                 decoder_start_token_id = bos;
@@ -921,7 +921,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 4a15941f1..921e5dc3b 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };

             const int ret = llama_decode(ctx, batch_view);
diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp
index 41bf4eb2a..102e3c517 100644
--- a/examples/cvector-generator/cvector-generator.cpp
+++ b/examples/cvector-generator/cvector-generator.cpp
@@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {

 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_cache_clear(ctx);
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
     }
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 6d629fe4e..0ab2f0d0e 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const gpt_params & params) {

     std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;
     }
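
Every hunk above makes the same mechanical change: `llama_batch_get_one()` loses its trailing `pos_0` and `seq_id` arguments. A minimal sketch of the new call pattern (hypothetical helper, not part of the patch; it assumes the defaults are supplied downstream by `llama_batch_allocr`, shown in the final `src/llama.cpp` hunk):

```cpp
#include "llama.h"
#include <vector>

// Evaluate a tokenized prompt with the two-argument llama_batch_get_one().
// Positions and the sequence id are no longer passed at the call site; the
// library fills in defaults (consecutive positions, sequence 0, logits for
// the last token only).
static bool eval_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    // before this patch: llama_batch_get_one(tokens.data(), tokens.size(), /*pos_0=*/0, /*seq_id=*/0)
    return llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size())) == 0;
}
```
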
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index c8e273529..f743d2269 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -508,12 +508,21 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-            // TODO: use batch.logits to save computations instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            for (int i = 0; i < batch_size; i++) {
+                batch. token[i]    = tokens[batch_start + i];
+                batch.   pos[i]    = j*n_batch + i;
+                batch.logits[i]    = true;
+                batch.seq_id[i][0] = 0;
+            }
+
+            if (llama_decode(ctx, batch)) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return false;
             }

+            llama_batch_free(batch);
+
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
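
The imatrix hunk above (and the perplexity hunks later in this patch) replaces the old reliance on `logits_all == true` with a hand-filled batch in which every position requests logits. A sketch of that pattern as a standalone function; setting `n_tokens` and `n_seq_id` explicitly here is my assumption about what `llama_batch_init()` leaves to the caller, not something the hunk itself shows:

```cpp
#include "llama.h"
#include <vector>

// Decode `tokens` starting at position `pos_0` on sequence 0, requesting
// logits for every token rather than only the last one.
static bool decode_all_logits(llama_context * ctx, const std::vector<llama_token> & tokens, llama_pos pos_0) {
    const int32_t n = (int32_t) tokens.size();

    llama_batch batch = llama_batch_init(n, /*embd=*/0, /*n_seq_max=*/1);
    batch.n_tokens = n; // assumed: llama_batch_init() allocates the arrays but leaves the count to the caller
    for (int32_t i = 0; i < n; i++) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = pos_0 + i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = true; // request logits at every position
    }

    const bool ok = llama_decode(ctx, batch) == 0;
    llama_batch_free(batch); // release the heap-allocated arrays on both paths
    return ok;
}
```
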
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index d52425ae6..67331eb70 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -396,7 +396,7 @@ int main(int argc, char ** argv) {

             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
+            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index fb1d387b2..d989140a6 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1446,7 +1446,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
+        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
         n_processed += n_tokens;
     }

@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;

     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
+        llama_decode(ctx, llama_batch_get_one(&token, 1));
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
     }
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index f611809c6..dd9f5a8db 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
         nullptr,
         nullptr,
         nullptr,
-        0,
-        0,
-        0,
     };

     if (embd) {
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 8f437863f..509927360 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 8558c6bdc..aa94ff22c 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
     return true;
 }

+struct llava_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ nullptr,
+            /*embd     =*/ embd,
+            /*pos      =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id   =*/ seq_ids.data(),
+            /*logits   =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));
@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
+        float * embd = image_embed->embed+i*n_embd;
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
index c5156c35b..b31832bb0 100644
--- a/examples/llava/minicpmv-cli.cpp
+++ b/examples/llava/minicpmv-cli.cpp
@@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
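
The `llava_embd_batch` wrapper above exists because an embeddings-only batch can no longer be brace-initialized with `*n_past` and a sequence id: the per-token position and sequence arrays now have to live somewhere, and the struct keeps them alive alongside the `llama_batch` that points into them. A sketch of how a caller might drive it (hypothetical function name; assumes the struct above is in scope):

```cpp
#include "llama.h"

// Push n_tokens rows of image embeddings through the model, advancing
// *n_past. No copy is made: batch.embd points directly at the caller's
// buffer, and the wrapper owns the pos/n_seq_id/seq_id/logits arrays.
static bool eval_embd_chunk(llama_context * ctx, float * embd, int32_t n_tokens, int * n_past) {
    llava_embd_batch wrapped(embd, n_tokens, /*pos_0=*/*n_past, /*seq_id=*/0);
    if (llama_decode(ctx, wrapped.batch) != 0) {
        return false;
    }
    *n_past += n_tokens;
    return true;
}
```
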
diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 49870b4a4..e10a27691 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -89,8 +89,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt
-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1));

     for (int s = 1; s < W + G + 1; ++s) {
         llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 2ccd0e6c1..db980e9d0 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -89,8 +89,8 @@ int main(int argc, char ** argv){

     const auto t_enc_start = ggml_time_us();

-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(),           1));

     const auto t_enc_end = ggml_time_us();
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6bbb1e13e..5801b4fe2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -528,7 +528,7 @@ int main(int argc, char ** argv) {
             int enc_input_size = embd_inp.size();
             llama_token * enc_input_buf = embd_inp.data();

-            if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
@@ -648,7 +648,7 @@ int main(int argc, char ** argv) {

             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
+            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 81e2f7ed7..6d3ad3d23 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -308,7 +308,6 @@ int main(int argc, char ** argv) {
                 batch.n_seq_id + i,
                 batch.seq_id   + i,
                 batch.logits   + i,
-                0, 0, 0, // unused
             };

             const int ret = llama_decode(ctx, batch_view);
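
The `batch_view` hunks (batched-bench earlier, parallel above, server and perplexity's `decode_helper` below) all shrink the brace initializer to the seven fields that remain in `llama_batch` after the `0, 0, 0, // unused` tail is dropped. The pattern itself is unchanged: decode a big batch in windows by pointing a view into the parent batch's arrays. A self-contained sketch of that pattern:

```cpp
#include "llama.h"
#include <algorithm>

// Decode `batch` in windows of at most n_batch tokens, without copying
// token data -- each view aliases the parent batch's arrays at offset i.
static bool decode_in_views(llama_context * ctx, llama_batch & batch, int32_t n_batch) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr, // token batch, so no embeddings
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        if (llama_decode(ctx, batch_view) != 0) {
            return false;
        }
    }
    return true;
}
```
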
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 87347135e..10c982b84 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -409,13 +409,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);

+            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            for (int i = 0; i < batch_size; i++) {
+                batch. token[i]    = tokens[batch_start + i];
+                batch.   pos[i]    = j*n_batch + i;
+                batch.logits[i]    = true;
+                batch.seq_id[i][0] = 0;
+            }
+
             //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-            // TODO: use llama_batch.logits instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 //LOG_ERR("%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }

+            llama_batch_free(batch);
+
             // save original token and restore it after eval
             const auto token_org = tokens[batch_start];
@@ -699,7 +708,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
             batch.n_seq_id + i,
             batch.seq_id   + i,
             batch.logits   + i,
-            0, 0, 0, // unused
         };

         const int ret = llama_decode(ctx, batch_view);
@@ -1790,12 +1798,21 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-            // TODO: use llama_batch.logits instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            for (int i = 0; i < batch_size; i++) {
+                batch. token[i]    = tokens[batch_start + i];
+                batch.   pos[i]    = j*n_batch + i;
+                batch.logits[i]    = true;
+                batch.seq_id[i][0] = 0;
+            }
+
+            if (llama_decode(ctx, batch)) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return;
             }

+            llama_batch_free(batch);
+
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 0117d9357..72f94beb8 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -49,7 +49,7 @@ int main(int argc, char ** argv) {
     auto tokens = llama_tokenize(ctx, params.prompt, true);

     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0));
+    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
     n_past += tokens.size();

     // save state (rng, logits, embedding and kv_cache) to file
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;

-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx);
             llama_free_model(model);
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;

-        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) {
+        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -221,7 +221,7 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;

-        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
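
In the save-load-state hunks above, the generation loops stop threading `n_past` through the batch; only the token pointer and count remain, with positions tracked by the context's KV cache. A sketch of one step of such a loop end to end (hypothetical greedy argmax, since the example's actual sampling code lies outside these hunks):

```cpp
#include "llama.h"

// Decode one token and greedily pick the next one from the logits of the
// single position that was just evaluated (batch index 0 in a 1-token batch).
static llama_token next_token_greedy(llama_context * ctx, llama_token token) {
    if (llama_decode(ctx, llama_batch_get_one(&token, 1)) != 0) {
        return -1; // decode failed
    }
    const llama_model * model  = llama_get_model(ctx);
    const float       * logits = llama_get_logits_ith(ctx, 0);

    llama_token best = 0;
    for (llama_token v = 1; v < llama_n_vocab(model); v++) {
        if (logits[v] > logits[best]) {
            best = v;
        }
    }
    return best;
}
```
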
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f343cc252..163b2ac89 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2283,7 +2283,6 @@ struct server_context {
                     batch.n_seq_id + i,
                     batch.seq_id   + i,
                     batch.logits   + i,
-                    0, 0, 0, // unused
                 };

                 const int ret = llama_decode(ctx, batch_view);
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index adf6255e1..bc67aae39 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -155,9 +155,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input,     0,           0));
+    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1));
+    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));

     const auto t_enc_end = ggml_time_us();
diff --git a/src/llama.cpp b/src/llama.cpp
index 825999b3c..4c83b9af7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21144,6 +21144,7 @@ struct llama_batch_allocr {
             logits[logits.size() - 1] = true;
             batch.logits = logits.data();
         }
+        return batch;
     }
 };
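
The one-line `src/llama.cpp` fix matters more than it looks: `llama_batch_allocr` is the new home for the defaults that the removed `pos_0`/`seq_id` arguments (and the old `all_pos_0`/`all_pos_1`/`all_seq_id` fields) used to carry, and without the `return batch;` the fulfilled batch never reached the caller. A paraphrased sketch of the allocator's job, not the literal implementation (the struct and member names here are illustrative):

```cpp
#include "llama.h"
#include <vector>

// For batches coming from llama_batch_get_one(), any array the caller left
// null is backed by temporary storage holding the old implicit defaults.
struct batch_defaults {
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<int8_t>         logits;
    llama_seq_id                seq_id_0 = 0; // the old all_seq_id default

    llama_batch fill(llama_batch batch, llama_pos p0) {
        const int32_t n = batch.n_tokens;
        if (batch.pos == nullptr) {       // consecutive positions from p0
            pos.resize(n);
            for (int32_t i = 0; i < n; i++) {
                pos[i] = p0 + i;
            }
            batch.pos = pos.data();
        }
        if (batch.n_seq_id == nullptr) {  // one sequence per token
            n_seq_id.assign(n, 1);
            batch.n_seq_id = n_seq_id.data();
        }
        if (batch.seq_id == nullptr) {    // everything on sequence 0
            seq_id.assign(n, &seq_id_0);
            batch.seq_id = seq_id.data();
        }
        if (batch.logits == nullptr) {    // logits for the last token only
            logits.assign(n, 0);
            logits[n - 1] = 1;
            batch.logits = logits.data();
        }
        return batch; // the step the hunk above restores
    }
};
```
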