From 96811fdf637b52ee0885875d2d8eb3a51111563d Mon Sep 17 00:00:00 2001
From: Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com>
Date: Sat, 25 May 2024 14:23:57 -0400
Subject: [PATCH] duo: v2

---
 examples/duo/duo.cpp | 77 +++++++++++++++++---------------------------
 1 file changed, 29 insertions(+), 48 deletions(-)

diff --git a/examples/duo/duo.cpp b/examples/duo/duo.cpp
index 9ba5e880c..e98b22893 100644
--- a/examples/duo/duo.cpp
+++ b/examples/duo/duo.cpp
@@ -88,6 +88,24 @@ static llama_tokens greedy_tokens(llama_model * model, llama_context * ctx, int3
     return res;
 }
 
+template<typename iter_t>
+static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool all_logits, llama_batch & batch)
+{
+    llama_batch_clear(batch);
+    size_t i = offset;
+    for (auto it = from; it != to; ++it)
+    {
+        llama_batch_add(batch, *it, i++, { 0 }, all_logits);
+    }
+    batch.logits[batch.n_tokens - 1] = true;
+    int res = 0;
+    if (llama_decode(ctx, batch) != 0) {
+        fprintf(stderr, "llama_decode() failed\n");
+        res = 1;
+    }
+    return res;
+}
+
 static int speculation(
     llama_model * model,
     speculation_context * spec_ctx,
@@ -96,20 +114,9 @@ static int speculation(
     // TODO: check that input is non-empty
 
     llama_batch batch = llama_batch_init(512, 0, 1);
+    decode(ctx, input.begin(), input.end(), 0, false, batch);
 
-    for (size_t i = 0; i < input.size(); i++)
-    {
-        llama_batch_add(batch, input[i], i, { 0 }, false);
-    }
-
-    batch.logits[batch.n_tokens - 1] = true;
-
-    if (llama_decode(ctx, batch) != 0) {
-        LOG_TEE("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
-
-    int logit_idx = batch.n_tokens - 1;
+    int logit_idx = input.size() - 1;
 
     llama_tokens local = input;
     size_t match_len;
@@ -167,20 +174,10 @@ static int speculation(
             }
         }
 
-        llama_batch_clear(batch);
-        // TODO theoretically this can be empty?
-        for (size_t i = match_len; i < local.size(); i++)
-        {
-            llama_batch_add(batch, local[i], i, { 0 }, true);
-        }
-        logit_idx = batch.n_tokens - 1;
+        decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
 
-        if (llama_decode(ctx, batch) != 0)
-        {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
+        logit_idx = local.size() - match_len - 1;
     }
 
     llama_batch_free(batch);
@@ -194,20 +191,11 @@ static int target(
     size_t n_predict)
 {
     dbg_default(to_string(ctx, input.begin(), input.end()));
-    // TODO: create int decode()
+
     llama_batch batch = llama_batch_init(512, 0, 1);
-    for (size_t i = 0; i < input.size(); i++)
-    {
-        llama_batch_add(batch, input[i], i, { 0 }, false);
-    }
-    batch.logits[batch.n_tokens - 1] = true;
+    decode(ctx, input.begin(), input.end(), 0, false, batch);
 
-    if (llama_decode(ctx, batch) != 0) {
-        fprintf(stderr, "llama_decode() failed\n");
-        return 1;
-    }
-
     // how many tokens are currently accepted
     // TODO: rename to n_accepted
     size_t n_cur = input.size();
     size_t n_decode = 0;
@@ -215,8 +203,8 @@ static int target(
     const auto t_main_start = ggml_time_us();
 
     // we'll use logits from this position to determine next token
-    int logits_from = batch.n_tokens - 1;
-    int logits_to = batch.n_tokens;
+    int logits_from = input.size() - 1;
+    int logits_to = input.size();
 
     llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
@@ -299,15 +287,8 @@ static int target(
             break;
         }
 
-        llama_batch_clear(batch);
-        for (size_t i = 0; i < input_seq.size(); i++)
-        {
-            llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
-        }
-        if (llama_decode(ctx, batch)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
+        decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
+
         logits_from = 0;
         logits_to = input_seq.size();
     }
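
For reference, the invariant the new helper relies on: decode() always sets
batch.logits[batch.n_tokens - 1] = true, so after a call with
all_logits == false the only valid logits row is the last one, at index
(to - from) - 1. That is why the callers above can compute logit_idx from
the token counts they already hold instead of reading batch.n_tokens. A
minimal sketch of that bookkeeping (not part of the patch; the function
name next_after_prompt is hypothetical, the context and batch are assumed
to be initialized as in duo.cpp, and llama_get_logits_ith / llama_n_vocab /
llama_get_model are upstream llama.cpp accessors):

    #include <algorithm>
    #include <iterator>

    #include "common.h"
    #include "llama.h"

    // Greedy-pick the token that follows a freshly decoded prompt. With
    // all_logits == false, decode() leaves exactly one valid logits row:
    // position input.size() - 1, which equals batch.n_tokens - 1.
    static llama_token next_after_prompt(llama_context * ctx, const llama_tokens & input, llama_batch & batch) {
        if (decode(ctx, input.begin(), input.end(), /* offset */ 0, /* all_logits */ false, batch) != 0) {
            return -1; // decode failed; -1 is not a valid token id
        }
        const float * logits  = llama_get_logits_ith(ctx, (int) input.size() - 1);
        const int     n_vocab = llama_n_vocab(llama_get_model(ctx));
        // argmax over the vocabulary == greedy sampling
        return (llama_token) std::distance(logits, std::max_element(logits, logits + n_vocab));
    }

With all_logits == true (the verification pass in target()), every row in
[0, input_seq.size()) is valid, which is why logits_from and logits_to are
reset to 0 and input_seq.size() after that call.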