From 13e08d0efa8f1803cd71a8413eb531f9cc3dfb2e Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Mon, 23 Oct 2023 02:40:37 -0600
Subject: [PATCH] Sync latest changes

---
 examples/perplexity/perplexity.cpp   |  11 +-
 examples/speculative/speculative.cpp | 301 ++++++++++++++++++++++-----
 2 files changed, 258 insertions(+), 54 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 62d55fee5..c9b393caa 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -397,6 +397,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             }
         );
         const size_t num_prune = std::min(pass_results.size(), prune_target);
+        if (num_prune > 0) printf("\nPruning: ");
        for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
            int32_t lidx = std::get<0>(pass_results[temp]);
            if (anti_mode) {
@@ -405,17 +406,17 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
            }
            if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
            extremes[lidx] |= std::get<1>(pass_results[temp]);
-            printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
+            printf("[%zu: %d (%d) - %.2f], ", pruned + 1, lidx,
                std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
            if (++pruned >= num_prune) break;
        }
    }
    pass_results.clear();
-    printf("\n\nADD %c%3d - ppl vs ref %.4f",
+    printf("\n\nADD %c%3d - ppl vs ref %.4f - cur:[",
        int(label[curr_best_type]), curr_best_layer, curr_best_ppl - ref_ppl);
    if (!anti_mode) {
-        if (curr_best_ppl > ref_ppl * 1.75) break;
+        // if (curr_best_ppl > ref_ppl * 1.75) break;
        skip_types[curr_best_layer] += curr_best_type;
        skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
    }
@@ -426,6 +427,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
    for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
        skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
    }
+    for (int32_t i = 0; i < n_layers; i++) {
+        const int val = mask ^ (skip_types[i] & 3);
+        printf("%d%s", val, i < n_layers - 1 ? ", " : "]");
", " : "]"); + } for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) { int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3); // printf("||%d, %d\n", new_sl, curr_skipped); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 5830b4fb3..3d8dc1347 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -15,6 +15,8 @@ struct seq_draft { bool drafting = false; bool skip = false; + int split_pos = 0; + int i_batch_dft = 0; std::vector i_batch_tgt; @@ -27,7 +29,7 @@ struct seq_draft { static void save_logits(llama_context * ctx, std::vector & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) { // printf("SAVE %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs); // printf(""); - GGML_ASSERT(doffs + count <= 30); + GGML_ASSERT(doffs + count < 64); memcpy( v.data() + doffs * n_vocab, llama_get_logits(ctx) + soffs * n_vocab, @@ -37,13 +39,47 @@ static void save_logits(llama_context * ctx, std::vector & v, const int n static void restore_logits(llama_context * ctx, std::vector & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) { // printf(""); // printf("REST %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs); - GGML_ASSERT(soffs + count <= 30); + GGML_ASSERT(soffs + count < 64); memcpy( llama_get_logits(ctx) + doffs * n_vocab, v.data() + soffs * n_vocab, sizeof(float) * size_t(n_vocab) * count); } +static llama_token_data_array normalize_candidates(const float * logits, const int n_vocab, std::vector & cur) { + cur.reserve(n_vocab); + cur.clear(); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + llama_sample_top_k(NULL, &cur_p, 100, 1); + llama_sample_softmax(NULL, &cur_p); + cur.resize(cur_p.size); + return cur_p; +} + +static int32_t find_normalized(const llama_token_data_array & tda, const llama_token id) { + llama_token_data *item = tda.data; + + for (int32_t i = 0; i < tda.size; i++, item++) + if (item->id == id) return i; + return -1; +} + +static double running_average(double & cur, double val, double n = 20) { + if (cur < 1e-5f) { + cur = val; + return cur; + } + // New average = old average * (n-1)/n + new value /n + cur = cur * (n - 1) / n + val / n; + return cur; +} + + int main(int argc, char ** argv) { gpt_params params; @@ -62,8 +98,8 @@ int main(int argc, char ** argv) { // TODO: make this configurable // const float p_accept = 0.80f; // const float p_split = 0.10f; - const float p_accept = 0.5f; // 0.80f; - const float p_split = p_accept / 8; // 0.10f; + const float p_accept = 0.75f; // 0.80f; + const float p_split = 0.6f; // p_accept / 8; // 0.10f; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -130,8 +166,8 @@ int main(int argc, char ** argv) { // eval the prompt with both models llama_batch_clear(batch_tgt); - logits_tgt.resize(n_vocab * 30); - logits_dft.resize(n_vocab * 30); + logits_tgt.resize(n_vocab * 64); + logits_dft.resize(n_vocab * 64); for (int i = 0; i < n_input - 1; i++) { llama_batch_add(batch_tgt, inp[i], i, { 0 }, false); } @@ -146,7 +182,7 @@ int main(int argc, char ** argv) { // llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS)); llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, -1); } - // save_logits(ctx_dft, logits_dft, n_vocab, n_input); + 

     const auto t_enc_end = ggml_time_us();

@@ -161,6 +197,11 @@ int main(int argc, char ** argv) {
     int n_accept  = 0;

     int n_split     = 0;
     int n_bad_split = 0;
+    int n_dup_split = 0;
+    int n_eff_split = 0;
+    int max_streak  = 0;
+
+    int64_t t_dft_sample = 0, t_dft_gen = 0, t_dft_accept = 0, t_tgt_predict = 0;

     int n_past_tgt = inp.size();
     int n_past_dft = inp.size();
@@ -170,26 +211,35 @@ int main(int argc, char ** argv) {
     // target model sampling context
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_dft_sampling = llama_sampling_init(params.sparams);
+    std::vector<llama_token_data> normalized_candidates;
+    normalized_candidates.reserve(n_vocab);
+    llama_token_data_array normalized_p;

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

     params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sparams.temp = std::max(0.01f, params.sparams.temp);
+    // params.sparams.temp = std::max(0.01f, params.sparams.temp);

     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
     }

-    // std::vector<int32_t> run_layers_dft = {
-    //     0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
-    //     3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
-    //     0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
-    //     3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
+    // 70B (80 layers) skips example
     std::vector<int32_t> run_layers_dft = {
-        0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1,
-        1, 0, 1, 0, 0, 0, -1, };
+        0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
+        3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
+        0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
+        3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
+
+    // 3B (26 layers) skips example
+    // std::vector<int32_t> run_layers_dft = {
+    //     0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 3, 0, 2, 3, 3, 1, 0, 2, 0, 1, 1, 2, 0, 0,
+    //     // 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 2, 1, 3, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 2, 0, 1,
+    //     -1, };
+
+    // NOTE: Comment this line out to disable skipping.
     batch_dft.run_layers = run_layers_dft.data();

     const auto t_dec_start = ggml_time_us();

@@ -198,8 +248,13 @@ int main(int argc, char ** argv) {
     drafts[0].i_batch_tgt.resize(1);
     drafts[0].i_batch_tgt[0] = 0;

-    double avg_accepted = 0, avg_rejected = 0;
-    float min_accepted = 0, max_rejected = 0;
+    double avg_accepted = 0, avg_rejected = 0, tgt_avg_accepted = 0;
+    double avg_accept_delta = 0;
+    float min_accepted = 0, max_rejected = 0, tgt_min_accepted = 0;
+
+    int64_t t_cur;
+
+    std::vector<std::vector<std::vector<llama_token_data>>> doubt;

     while (true) {
         LOG("*** Draft start\n");
@@ -217,15 +272,37 @@ int main(int argc, char ** argv) {
         int i_dft  = 0;
         int s_keep = 0;

+        float tgt_last_norm = 0, tgt_last_best_norm = 0, tgt_last_orig = 0;
+
         while (true) {
             LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);

             // sample from the target model
             restore_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
+            normalized_p = normalize_candidates(llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]), n_vocab, normalized_candidates);
             llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
+            int32_t norm_pos = find_normalized(normalized_p, id);
+            int32_t orig_pos = find_normalized({ctx_sampling->cur.data(), ctx_sampling->cur.size(), false}, id);
+            if (norm_pos >= 0) {
+                tgt_last_norm = normalized_candidates[norm_pos].p;
+                tgt_last_best_norm = normalized_candidates[0].p;
+                running_average(tgt_avg_accepted, tgt_last_norm);
+                tgt_min_accepted = tgt_min_accepted < 1e-4
+                    ? tgt_last_norm
+                    : std::min(tgt_min_accepted, tgt_last_norm);
+            } else {
+                tgt_last_norm = tgt_last_best_norm = tgt_avg_accepted;
+            }
+            if (orig_pos >= 0) {
+                tgt_last_orig = ctx_sampling->cur[orig_pos].p;
+            }
+            LOG("target sampled (%d, '%s') orig_p=%5.4f, norm_p=%5.4f\n",
+                id, llama_token_to_piece(ctx_tgt, id).c_str(),
+                orig_pos >= 0 ? ctx_sampling->cur[orig_pos].p : -1,
+                norm_pos >= 0 ? normalized_candidates[norm_pos].p : -1);

             llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
-            save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);

             //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@@ -245,26 +322,30 @@ int main(int argc, char ** argv) {
             bool matches = false;

             for (int s = 0; s < n_seq_dft; ++s) {
-                if (!drafts[s].active) {
+                if (!drafts[s].active || i_dft < drafts[s].split_pos) {
                     continue;
                 }

                 if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
-                    LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
+                    LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n",
+                        i_dft, s, id, token_str.c_str());
+                    if (i_dft == 0 && s > 0) {
+                        if (matches) n_dup_split++;
+                        else n_eff_split++;
+                    }

                     s_keep = s;
                     matches = true;
                     LOG("Derp[%d]: %6d (%5.4f)\n", s, drafts[s].tokens[i_dft], drafts[s].tokens_p[i_dft]);
                     if (min_accepted == 0) min_accepted = drafts[s].tokens_p[i_dft];
                     else min_accepted = std::min(min_accepted, drafts[s].tokens_p[i_dft]);
-                    avg_accepted += drafts[s].tokens_p[i_dft] * (avg_accepted == 0 ? 2 : 1);
-                    avg_accepted /= 2;
+                    running_average(avg_accepted, drafts[s].tokens_p[i_dft]);
+                    running_average(avg_accept_delta, tgt_last_norm - drafts[s].tokens_p[i_dft]);
                 } else {
                     if (i_dft < (int) drafts[s].tokens.size() && id != drafts[s].tokens[i_dft]) {
                         if (i_dft == 0 && s > 0) n_bad_split++;
                         max_rejected = std::max(max_rejected, drafts[s].tokens_p[i_dft]);
-                        avg_rejected += drafts[s].tokens_p[i_dft] * (avg_rejected == 0 ? 2 : 1);
-                        avg_rejected /= 2;
+                        running_average(avg_rejected, drafts[s].tokens_p[i_dft]);
                         LOG("-- Terminate sequence %d+%d: (%d, '%s') != target (%d, '%s') - rejected\n", s, i_dft,
                             drafts[s].tokens[i_dft], llama_token_to_piece(ctx_dft, drafts[s].tokens[i_dft]).c_str(),
@@ -279,8 +360,27 @@ int main(int argc, char ** argv) {
                     ++n_past_tgt;
                     ++n_past_dft;
                     ++i_dft;
-
+                    max_streak = std::max(max_streak, i_dft);
                     continue;
+                } else {
+                    for (size_t seqnum = 0; seqnum < doubt.size(); seqnum++) {
+                        const std::vector<std::vector<llama_token_data>> & sdoubt = doubt[seqnum];
+                        if (sdoubt.size() <= i_dft) continue;
+                        const std::vector<llama_token_data> & sidoubt = sdoubt[i_dft];
+                        for (size_t cidx = 0; cidx < sidoubt.size(); cidx++) {
+                            if (sidoubt[cidx].id == id) {
+                                LOG("Shoulda picked seq %3zu, pos %4d, candidate %2zu @ p %5.4f: %6d '%s'\n",
+                                    seqnum, i_dft, cidx, sidoubt[cidx].p,
+                                    id, token_str.c_str());
+                                running_average(avg_accepted, sidoubt[cidx].p);
+                                if (cidx < 2) {
+                                    running_average(avg_accept_delta, tgt_last_norm - sidoubt[cidx].p);
+                                    min_accepted = min_accepted < 1e-5f ? sidoubt[cidx].p : std::min(min_accepted, sidoubt[cidx].p);
+                                }
+                                break;
+                            }
+                        }
+                    }
                 }
             }
@@ -315,6 +415,7 @@ int main(int argc, char ** argv) {
             }

             for (int s = 0; s < n_seq_dft; ++s) {
+                drafts[s].split_pos = 0;
                 drafts[s].active = false;
                 drafts[s].tokens.clear();
                 drafts[s].tokens_p.clear();
@@ -327,10 +428,18 @@ int main(int argc, char ** argv) {
             llama_batch_clear(batch_dft);
             llama_batch_add  (batch_dft, id, n_past_dft, { 0 + DOFFS }, true);

+            if (self_speculation) {
+                // Copy KV items from non-brain-damaged model... Doesn't seem to help.
+                llama_kv_cache_seq_rm(ctx_dft, 0 + DOFFS, 0, n_past_dft - 2);
+                llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, n_past_dft - 2);
+                // llama_kv_cache_seq_rm(ctx_dft, 0 + DOFFS, n_past_dft - 1, -1);
+                // llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, n_past_dft - 1, -1);
+            }
             LOG("=== EVAL: DRAFT ACCEPTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
-
+            t_cur = ggml_time_us();
             llama_decode         (ctx_dft, batch_dft);
+            t_dft_accept += ggml_time_us() - t_cur;
             save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);

             ++n_past_dft;
@@ -358,9 +467,14 @@ int main(int argc, char ** argv) {
         llama_batch_clear(batch_tgt);
         llama_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);

+        avg_rejected = std::max(0.05, std::min(avg_accepted - 0.05, avg_rejected));
+        avg_accepted = std::max(0.05, std::max(avg_rejected + 0.05, avg_accepted));
         // double avg_accepted = n_accept > 0 ? avg_accepted / double(n_accept) : 0;
-        LOG("Average accepted/rejected: %3.5f / %3.5f -- Min accepted/max rejected: %3.5f / %3.5f\n",
-            avg_accepted, avg_rejected, min_accepted, max_rejected);
+        LOG("STATS: Avg tacc/dacc/drej: %3.5f / %3.5f / %3.5f | Min dacc/min tacc/max drej: %3.5f / %3.5f / %3.5f | delta %3.5f | max streak %d | n_dft/pred/acc: %d / %d / %d\n",
+            tgt_avg_accepted, avg_accepted, avg_rejected, min_accepted, tgt_min_accepted, max_rejected, avg_accept_delta, max_streak,
+            n_drafted, n_predict, n_accept);
+        doubt.clear();
+        doubt.resize(n_seq_dft);

         // sample n_draft tokens from the draft model using tree-based sampling
         for (int i = 0; i < n_draft; ++i) {
@@ -371,43 +485,116 @@ int main(int argc, char ** argv) {
             }

             for (int s = 0; s < n_seq_dft; ++s) {
+                double accept_threshold, split_threshold;
+
                 if (!drafts[s].drafting || drafts[s].skip) {
                     continue;
                 }
+                doubt[s].push_back({});

-                restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
-                llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
-                save_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
+                if (avg_accepted == 0 || avg_rejected == 0 || n_drafted + n_predict < 6) {
+                    accept_threshold = std::max(0.6f, tgt_last_norm);
+                } else {
-                const auto & cur_p = drafts[s].ctx_sampling->cur;
+                    accept_threshold = (tgt_avg_accepted - avg_accept_delta) * 0.3;
+                    accept_threshold *= std::min(0.8, std::max(0.1, double(tgt_last_norm * 1.0)));
+                    accept_threshold = std::max(double(min_accepted) * 1.1, accept_threshold);
+                    accept_threshold = std::max(std::max(avg_accepted * 0.9, avg_rejected * 1.1), accept_threshold);
+                    accept_threshold += 1.0 - (1.2 * n_accept / n_drafted);
+                    accept_threshold *= (1.3 - (std::min(n_seq_cur + i, 6) * 0.1));
+                    //
+                    // accept_threshold = (tgt_avg_accepted - avg_accept_delta) * 0.3;
+                    // accept_threshold *= std::min(0.8, std::max(0.1, double(tgt_last_norm * 1.0)));
+                    // accept_threshold = std::max(double(min_accepted) * 1.1, accept_threshold);
+                    // accept_threshold = std::max(std::max(avg_accepted * 0.9, avg_rejected * 1.1), accept_threshold);
+                    // accept_threshold += 1.0 - (1.2 * n_accept / n_drafted);
+                    // accept_threshold *= (0.7 + (std::min(n_seq_cur + i, 5) * 0.1));

-                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
-                    if (cur_p[k].p < 1e-5f) continue;
-                    LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                        k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
                 }

-                double accept_threshold = avg_rejected == 0 || avg_rejected == 0 || n_drafted < 16
-                    ? p_accept
-                    : std::max(double(min_accepted * 0.98), avg_accepted * 0.75f);
-                // accept_threshold = 0.8;
-                if (cur_p[0].p < accept_threshold) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, accept_threshold);
+                std::vector<llama_token_data> cur_p;
+                {
+                    llama_token d_id;
+                    std::vector<llama_token> already_picked;
+                    float * logits = NULL;
+
+                    t_cur = ggml_time_us();
+                    for (int cidx = 0; cidx < 9; cidx++) {
+                        llama_sampling_cp(drafts[s].ctx_sampling, ctx_dft_sampling);
+                        restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft);
+                        logits = llama_get_logits(ctx_dft);
+                        normalized_p = normalize_candidates(logits, n_vocab, normalized_candidates);
+                        for (size_t x = 0; x < std::min(normalized_p.size, size_t(10)); x++)
+                            doubt[s].back().push_back(normalized_p.data[x]);
+                        for (const auto & tid : already_picked)
+                            logits[tid] = std::numeric_limits<float>::infinity() * -1;
+                        d_id = llama_sampling_sample(ctx_dft_sampling, ctx_dft, NULL);
+                        already_picked.push_back(d_id);
+                        int32_t norm_pos = find_normalized(normalized_p, d_id);
+                        if (norm_pos < 0) continue;
+                        llama_token_data norm = normalized_candidates[norm_pos];
+                        if (norm.p < 0.2) continue;
+                        if (ctx_dft_sampling->params.temp <= 0) {
+                            llama_token_data_array tda = { ctx_dft_sampling->cur.data(), ctx_dft_sampling->cur.size(), false };
+                            llama_sample_top_k(ctx_dft, &tda, 100, 1);
+                            llama_sample_softmax(ctx_dft, &tda);
+                            ctx_dft_sampling->cur.resize(tda.size);
+                        }
+
+                        llama_token_data found;
+                        found.id = -1;
+                        for (const llama_token_data & td : ctx_dft_sampling->cur) {
+                            if (td.id == d_id) {
+                                found = td;
+                                break;
+                            }
+                        }
+                        GGML_ASSERT(found.id != -1);
+                        LOG(" ** draft candidate %3d for seq %3d, pos %3d: %6d (%4.3f, norm %4.3f) '%s'\n",
+                            cidx, s, i, found.id, found.p, norm_pos >= 0 ? normalized_candidates[norm_pos].p : -1,
+                            llama_token_to_piece(ctx_dft, found.id).c_str());
+                        if (found.p < 0.3) continue;
+                        if (norm.p < 1e-2f) break;
+                        cur_p.push_back(normalized_candidates[norm_pos]);
+                    }
+
+                    if (cur_p.size() > 1) {
+                        std::sort(cur_p.begin() + 1, cur_p.end(),
+                            [](const llama_token_data & a, const llama_token_data & b) {
+                                return a.p > b.p;
+                            }
+                        );
+                    }
+                }
+
+                t_dft_sample += ggml_time_us() - t_cur;
+
+                if (cur_p.empty()) {
+                    LOG("stopping drafting for seq %3d, no viable candidates (%5.3f)\n", s, accept_threshold);
+                    drafts[s].drafting = false;
+                    continue;
+                } else if (cur_p[0].p < accept_threshold && (cur_p[0].p + (cur_p.size() < 2 ? 0 : cur_p[1].p)) < accept_threshold * 1.3) {
+                    LOG("stopping drafting for seq %3d, pos %3d - probability too low: %.3f < %.3f\n", s, i, cur_p[0].p, accept_threshold);
                     drafts[s].drafting = false;
                     continue;
                 }

+                if (cur_p[0].p < accept_threshold) {
+                    split_threshold = 0.0;
+                } else {
+                    split_threshold = cur_p[0].p / 10.0;
+                    // split_threshold = std::max(0.01, cur_p[0].p * (n_seq_cur + i > 1 ? 0.15 : 0.2));
+                }
+
                 std::vector<int> sa(1, s);
+
+                // LOG("Check splits: %zu\n", cur_p.size());
                 // attempt to split the branch if the probability is high enough
-                for (int f = 1; f < 8; ++f) {
-                    // if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
-                    // if (n_seq_cur < n_seq_dft && cur_p[f].p > cur_p[0].p / 5) {
-                    double split_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16
-                        ? p_split
-                        : ( std::max(double(min_accepted * 0.7), avg_accepted * 0.4)
-                            * (n_seq_cur >= 2 ? 0.75 : 1.0) );
-                    // split_threshold = 0.1;
+                for (int f = 1; f < std::min(8, int(cur_p.size()) - 1); ++f) {
                     if (n_seq_cur < n_seq_dft && cur_p[f].p >= split_threshold) {
                         n_split++;
                         LOG(">>>%d<<< splitting seq %3d into %3d on %6d (%8.3f) '%s'\n", f, s, n_seq_cur,
@@ -428,6 +615,7 @@ int main(int argc, char ** argv) {
                         }

                         // copy the draft state
+                        drafts[n_seq_cur].split_pos = i;
                         drafts[n_seq_cur].active = true;
                         drafts[n_seq_cur].drafting = true;
                         drafts[n_seq_cur].skip = true;
@@ -443,6 +631,8 @@ int main(int argc, char ** argv) {
                         n_seq_cur++;
                     } else {
+                        LOG("Not splitting seq %3d into %3d, choice %2d @ %6d (%8.3f) '%s'\n", s, n_seq_cur, f,
+                            cur_p[f].id, cur_p[f].p, llama_token_to_piece(ctx_dft, cur_p[f].id).c_str());
                         break;
                     }
                 }
@@ -488,7 +678,9 @@ int main(int argc, char ** argv) {
             LOG("=== EVAL: DRAFTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());

             // evaluate the drafted tokens on the draft model
+            t_cur = ggml_time_us();
             llama_decode(ctx_dft, batch_dft);
+            t_dft_gen += ggml_time_us() - t_cur;
             save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
             ++n_past_cur;
             ++n_drafted;
@@ -509,7 +701,9 @@ int main(int argc, char ** argv) {
         }

         LOG("=== EVAL: TARGET ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
+        t_cur = ggml_time_us();
         llama_decode(ctx_tgt, batch_tgt);
+        t_tgt_predict += ggml_time_us() - t_cur;
         save_logits(ctx_tgt, logits_tgt, n_vocab, batch_tgt.n_tokens);
         ++n_past_tgt;
     }
@@ -531,7 +725,9 @@ int main(int argc, char ** argv) {
     LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
     LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
-
+    LOG_TEE("times: target predict: %5.3f, draft gen/accept/sample: %5.3f / %5.3f / %5.3f\n",
+        t_tgt_predict / 1e6f, t_dft_gen / 1e6f, t_dft_accept / 1e6f, t_dft_sample / 1e6f);
+// int64_t t_dft_sample = 0, t_dft_gen = 0, t_dft_accept = 0, t_tgt_predict = 0;
     LOG_TEE("\n");
     LOG_TEE("n_draft   = %d\n", n_draft);
     LOG_TEE("n_predict = %d\n", n_predict);
@@ -541,7 +737,10 @@ int main(int argc, char ** argv) {
     LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
     LOG_TEE("n_draft   = %d\n", n_draft);
     LOG_TEE("n_split   = %d\n", n_split);
+    LOG_TEE("n_effsplit= %d\n", n_eff_split);
     LOG_TEE("n_badsplit= %d\n", n_bad_split);
+    LOG_TEE("n_dupsplit= %d\n", n_dup_split);
+    LOG_TEE("max streak= %d\n", max_streak);

     LOG_TEE("\ndraft:\n");
     llama_print_timings(ctx_dft);
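
For readers auditing the patch, the statistics machinery above reduces to two small helpers: running_average(), an exponential moving average with an effective window of n samples, and normalize_candidates(), which softmaxes only the top-k logits so draft and target token probabilities become comparable. The following is a minimal standalone sketch of both (plain C++, no llama.cpp dependency; normalize_top_k is an invented name for illustration, and it keeps only probabilities where the patch also carries token ids):

// Standalone sketch of the patch's two statistics helpers (assumptions:
// plain vectors stand in for llama_token_data_array; values only, no ids).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

// Mirrors running_average(): EMA with an effective window of n samples.
// A near-zero average is treated as "uninitialized" and snaps to the
// first sample, exactly as in the patch.
static double running_average(double & cur, double val, double n = 20) {
    if (cur < 1e-5) {
        cur = val;
        return cur;
    }
    // new_avg = old_avg * (n-1)/n + sample/n
    cur = cur * (n - 1) / n + val / n;
    return cur;
}

// Mirrors the effect of normalize_candidates(): keep the top-k logits and
// softmax them (subtracting the max logit for numerical stability), giving
// a probability distribution over just those k candidates.
static std::vector<double> normalize_top_k(std::vector<double> logits, size_t k) {
    std::sort(logits.begin(), logits.end(), std::greater<double>());
    logits.resize(std::min(k, logits.size()));
    const double max_logit = logits[0];
    double sum = 0.0;
    for (double & l : logits) { l = std::exp(l - max_logit); sum += l; }
    for (double & l : logits) { l /= sum; }
    return logits;
}

int main() {
    double avg = 0.0;
    for (double p : {0.9, 0.7, 0.8}) {
        running_average(avg, p); // 0.9000, then 0.8900, then 0.8855
    }
    printf("avg = %.4f\n", avg);

    for (double p : normalize_top_k({2.0, 1.0, 0.5, -3.0}, 3)) {
        printf("%.4f ", p);      // 0.6285 0.2312 0.1402
    }
    printf("\n");
    return 0;
}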
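The adaptive accept_threshold in the drafting loop is the least obvious arithmetic in the patch. The sketch below reproduces the non-warm-up branch as a pure function so the effect of each clamp can be checked with made-up inputs (all tuning constants are the patch's own; the function name and the sample values are invented for illustration; the patch only reaches this branch once both averages are non-zero and n_drafted + n_predict >= 6, so the division is safe):

#include <algorithm>
#include <cstdio>

// Mirrors the patch's adaptive threshold: start from the target model's
// recent acceptance stats, clamp by observed minima/averages, penalize a
// low overall accept rate, and tighten as more sequences/positions are live.
static double accept_threshold_sketch(
        double tgt_avg_accepted, double avg_accept_delta, double tgt_last_norm,
        double min_accepted, double avg_accepted, double avg_rejected,
        int n_accept, int n_drafted, int n_seq_cur, int i) {
    double t = (tgt_avg_accepted - avg_accept_delta) * 0.3;
    t *= std::min(0.8, std::max(0.1, tgt_last_norm));
    t  = std::max(min_accepted * 1.1, t);
    t  = std::max(std::max(avg_accepted * 0.9, avg_rejected * 1.1), t);
    t += 1.0 - (1.2 * n_accept / n_drafted);
    t *= 1.3 - std::min(n_seq_cur + i, 6) * 0.1;
    return t;
}

int main() {
    // Made-up running statistics after a few dozen drafted tokens.
    const double t = accept_threshold_sketch(
        /*tgt_avg_accepted=*/0.85, /*avg_accept_delta=*/0.10, /*tgt_last_norm=*/0.90,
        /*min_accepted=*/0.30, /*avg_accepted=*/0.60, /*avg_rejected=*/0.20,
        /*n_accept=*/40, /*n_drafted=*/50, /*n_seq_cur=*/1, /*i=*/0);
    printf("accept_threshold = %.4f\n", t); // 0.6960; drafting stops sooner as it rises
    return 0;
}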