From e6dd81f0bc82afea13350b90f273016c2f6fe8f0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 17 Oct 2023 17:04:31 +0300 Subject: [PATCH] speculative : fix the n_drafted fix + p constants --- examples/speculative/speculative.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 48cdd4d31..53f42fad8 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -36,6 +36,10 @@ int main(int argc, char ** argv) { // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; + // TODO: make this configurable + const float p_accept = 0.4f; + const float p_split = 0.3f; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -272,8 +276,7 @@ int main(int argc, char ** argv) { k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); } - // TODO: make this configurable - if (cur_p[0].p < 0.4) { + if (cur_p[0].p < p_accept) { LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p); drafts[s].drafting = false; continue; @@ -283,8 +286,7 @@ int main(int argc, char ** argv) { // attempt to split the branch if the probability is high enough for (int f = 1; f < 8; ++f) { - // TODO: make this configurable - if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) { + if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); @@ -364,7 +366,9 @@ int main(int argc, char ** argv) { } // account for the last drafted token that we didn't evaluate - ++n_drafted; + if (batch_tgt.n_tokens > n_draft) { + ++n_drafted; + } // evaluate the target model on the drafted tokens {