From dddd784c4de23af43d851f323f75c8c9bb874492 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Sep 2023 08:49:40 +0300 Subject: [PATCH] speculative : improve heuristic impl --- examples/speculative/speculative.cpp | 33 ++++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 4610be59f..51562bcb1 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -116,7 +116,6 @@ int main(int argc, char ** argv) { // sample from the drafted tokens if any int i_dft = 0; - bool all_accepted = false; while (true) { const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft); @@ -143,9 +142,6 @@ int main(int argc, char ** argv) { ++n_past_dft; ++i_dft; - if (i_dft == (int) drafted.size()) { - all_accepted = true; - } continue; } @@ -153,20 +149,33 @@ int main(int argc, char ** argv) { llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); ++n_past_dft; + // heuristic for n_draft + { + const int n_dradt_cur = (int) drafted.size(); + const bool all_accepted = i_dft == n_dradt_cur; + + LOG("n_draft = %d\n", n_draft); + LOG("n_draft_cur = %d\n", n_dradt_cur); + LOG("i_dft = %d\n", i_dft); + LOG("all_accepted = %d\n", all_accepted); + + if (all_accepted && n_draft == n_dradt_cur) { + LOG(" - max drafted tokens accepted - n_draft += 2\n"); + n_draft += 2; + } else if (all_accepted) { + LOG(" - partially drafted tokens accepted - no change\n"); + } else { + LOG(" - drafted token rejected - n_draft -= 1\n"); + n_draft = std::max(2, n_draft - 1); + } + } + drafted.clear(); drafted.push_back(id); break; } - if (drafted.size() > 0 && all_accepted) { - n_draft += 2; - LOG("all drafted tokens accepted, n_draft = %d\n", n_draft); - } else { - n_draft = std::max(2, n_draft - 1); - LOG("drafted token rejected, n_draft = %d\n", n_draft); - } - if (n_predict > params.n_predict || has_eos) { break; }