diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 7559c0287..fb3b018f2 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0); - const int32_t n_layers = 32; // model layer count - const int test_count = 6; // num perplexity chunks to run for each test - const size_t prune_target = 4; // prune this many of the worst results each pass - // end tunables + // model layer count + const int32_t n_layers = 32; + + // num perplexity chunks to run for each test + const int test_count = 4; + + // prune this many of the worst results each pass + const size_t prune_target = 2; + + // start with all but first/last layers disabled and start adding them back + const bool anti_mode = true; + + // **** end tunables *** // 1 = attn, 2 = mlp, 3 = both int32_t test_skip_type = 0; // but don't mess with this, it's set automatically. @@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par skip_types.resize(n_layers); std::fill(skip_types.begin(), skip_types.end(), 0); std::vector> pass_results; - std::vector worsts; - worsts.resize(n_layers); - std::fill(worsts.begin(), worsts.end(), 0); + std::vector extremes; + extremes.resize(n_layers); + std::fill(extremes.begin(), extremes.end(), 0); + if (anti_mode) { + // No pointing in starting with first/last layer disabled. + skip_types[0] = 15; + skip_types[n_layers - 1] = 15; + skips.push_back(0); skips.push_back(0 + n_layers); + skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers); + } int32_t curr_best_layer = -1, curr_best_type = 0; double curr_best_ppl = -1, ref_ppl = -1; + const int32_t mask = anti_mode ? 3 : 0; int count = 0; double nll = 0.0; @@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } if (skip_layer >= n_layers) { if (curr_best_layer == -1) break; - if (pass_results.size() >= prune_target * 2) { + if (prune_target > 0 && pass_results.size() >= prune_target * 2) { std::sort(pass_results.begin(), pass_results.end(), [](const std::tuple & a, const std::tuple & b) { + if (anti_mode) return std::get<2>(b) > std::get<2>(a); return std::get<2>(a) > std::get<2>(b); } ); const size_t num_prune = std::min(pass_results.size(), prune_target); - for (size_t temp = 0; temp < num_prune; temp++) { + for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) { int32_t lidx = std::get<0>(pass_results[temp]); if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue; - worsts[lidx] |= std::get<1>(pass_results[temp]); - printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp])); + extremes[lidx] |= std::get<1>(pass_results[temp]); + printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx, + std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp])); + if (anti_mode) { + skip_types[lidx] |= std::get<1>(pass_results[temp]); + skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : -lidx); + } + if (++pruned >= num_prune) break; } } pass_results.clear(); - printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f", + printf("\n\nADD %c%3d - ppl vs ref %.4f", int(label[curr_best_type]), curr_best_layer, curr_best_ppl - ref_ppl); - if (curr_best_ppl > ref_ppl * 1.75) break; + if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break; skip_types[curr_best_layer] += curr_best_type; - if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) { - skips.push_back(curr_best_layer); - } + skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers); curr_best_layer = -1; curr_best_ppl = -1; curr_best_type = 0; skip_layer = n_layers; for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) { - skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2); + skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2); } for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) { int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3); @@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par logit_history.clear(); prob_history.clear(); + int alive = 0; for (int32_t i = 0; i < n_layers; i++) { - layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0); + layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0)); + alive += !(layers[i] & 1) + !(layers[i] & 2); } layers[n_layers] = -1; printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer); - for (const auto l : skips) { - printf("%c%d, ", int(label[skip_types[l] & 3]), l); + for (auto l : skips) { + printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers); } - printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n", - skips.size() + 1, + printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n", + alive, n_layers * 2, int(label[curr_best_type]), curr_best_layer, curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0, test_t_total); @@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_end = std::chrono::high_resolution_clock::now(); - if (i == 0 && skip_layer < 0 && skips.empty()) { + if (i == 0 && skip_layer < 0 && ref_ppl < 0) { const float t_total = std::chrono::duration(t_end - t_start).count(); fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); int total_seconds = (int)(t_total * n_chunk); @@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } fflush(stdout); - if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) { + if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) { i = test_count - 1; skip_types[skip_layer] |= test_skip_type << 2; if (curr_best_layer == -1 || ppl < curr_best_ppl) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 894321ce9..5830b4fb3 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -8,6 +8,8 @@ #include #include +#define DOFFS 10000 + struct seq_draft { bool active = false; bool drafting = false; @@ -17,10 +19,31 @@ struct seq_draft { std::vector i_batch_tgt; std::vector tokens; + std::vector tokens_p; struct llama_sampling_context * ctx_sampling; }; +static void save_logits(llama_context * ctx, std::vector & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) { + // printf("SAVE %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs); + // printf(""); + GGML_ASSERT(doffs + count <= 30); + memcpy( + v.data() + doffs * n_vocab, + llama_get_logits(ctx) + soffs * n_vocab, + sizeof(float) * size_t(n_vocab) * count); +} + +static void restore_logits(llama_context * ctx, std::vector & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) { + // printf(""); + // printf("REST %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs); + GGML_ASSERT(soffs + count <= 30); + memcpy( + llama_get_logits(ctx) + doffs * n_vocab, + v.data() + soffs * n_vocab, + sizeof(float) * size_t(n_vocab) * count); +} + int main(int argc, char ** argv) { gpt_params params; @@ -37,8 +60,10 @@ int main(int argc, char ** argv) { const int n_seq_dft = params.n_parallel; // TODO: make this configurable - const float p_accept = 0.80f; - const float p_split = 0.10f; + // const float p_accept = 0.80f; + // const float p_split = 0.10f; + const float p_accept = 0.5f; // 0.80f; + const float p_split = p_accept / 8; // 0.10f; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -46,6 +71,8 @@ int main(int argc, char ** argv) { log_dump_cmdline(argc, argv); #endif // LOG_DISABLE_LOGS + bool self_speculation = false; + // init llama.cpp llama_backend_init(params.numa); @@ -60,9 +87,18 @@ int main(int argc, char ** argv) { std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); // load the draft model - params.model = params.model_draft; - params.n_gpu_layers = params.n_gpu_layers_draft; - std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); + if (params.model != params.model_draft) { + params.model = params.model_draft; + params.n_gpu_layers = params.n_gpu_layers_draft; + std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); + } else { + self_speculation = true; + model_dft = model_tgt; + ctx_dft = ctx_tgt; + } + + const int n_ctx = llama_n_ctx(ctx_tgt); + const int n_vocab = llama_n_vocab(model_tgt); // tokenize the prompt std::vector inp; @@ -84,14 +120,33 @@ int main(int argc, char ** argv) { fflush(stderr); + llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); + llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft); + std::vector logits_tgt, logits_dft; + const int n_input = inp.size(); const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); + llama_batch_clear(batch_tgt); + logits_tgt.resize(n_vocab * 30); + logits_dft.resize(n_vocab * 30); + for (int i = 0; i < n_input - 1; i++) { + llama_batch_add(batch_tgt, inp[i], i, { 0 }, false); + } + llama_decode(ctx_tgt, batch_tgt); + llama_batch_clear(batch_tgt); + llama_batch_add(batch_tgt, inp.back(), n_input - 1, { 0 }, true); + llama_decode(ctx_tgt, batch_tgt); + save_logits(ctx_tgt, logits_tgt, n_vocab); + if (!self_speculation) { + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS)); + } else { + // llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS)); + llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, -1); + } + // save_logits(ctx_dft, logits_dft, n_vocab, n_input); const auto t_enc_end = ggml_time_us(); @@ -104,6 +159,8 @@ int main(int argc, char ** argv) { int n_predict = 0; int n_drafted = 0; int n_accept = 0; + int n_split = 0; + int n_bad_split = 0; int n_past_tgt = inp.size(); int n_past_dft = inp.size(); @@ -124,8 +181,16 @@ int main(int argc, char ** argv) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); } - llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft); + // std::vector run_layers_dft = { + // 0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1, + // 3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0, + // 0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0, + // 3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, }; + std::vector run_layers_dft = { + 0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1, + 1, 0, 1, 0, 0, 0, -1, }; + + batch_dft.run_layers = run_layers_dft.data(); const auto t_dec_start = ggml_time_us(); @@ -133,7 +198,11 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt.resize(1); drafts[0].i_batch_tgt[0] = 0; + double avg_accepted = 0, avg_rejected = 0; + float min_accepted = 0, max_rejected = 0; + while (true) { + LOG("*** Draft start\n"); // print current draft sequences for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { @@ -152,9 +221,11 @@ int main(int argc, char ** argv) { LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); // sample from the target model + restore_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]); llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); + save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); @@ -179,11 +250,26 @@ int main(int argc, char ** argv) { } if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); + LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); s_keep = s; matches = true; + LOG("Derp[%d]: %6d (%5.4f)\n", s, drafts[s].tokens[i_dft], drafts[s].tokens_p[i_dft]); + if (min_accepted == 0) min_accepted = drafts[s].tokens_p[i_dft]; + else min_accepted = std::min(min_accepted, drafts[s].tokens_p[i_dft]); + avg_accepted += drafts[s].tokens_p[i_dft] * (avg_accepted == 0 ? 2 : 1); + avg_accepted /= 2; } else { + if (i_dft < (int) drafts[s].tokens.size() && id != drafts[s].tokens[i_dft]) { + if (i_dft == 0 && s > 0) n_bad_split++; + max_rejected = std::max(max_rejected, drafts[s].tokens_p[i_dft]); + avg_rejected += drafts[s].tokens_p[i_dft] * (avg_rejected == 0 ? 2 : 1); + avg_rejected /= 2; + LOG("-- Terminate sequence %d+%d: (%d, '%s') != target (%d, '%s') - rejected\n", + s, i_dft, drafts[s].tokens[i_dft], + llama_token_to_piece(ctx_dft, drafts[s].tokens[i_dft]).c_str(), + id, token_str.c_str()); + } drafts[s].active = false; } } @@ -204,6 +290,18 @@ int main(int argc, char ** argv) { { LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); + llama_kv_cache_seq_rm(ctx_dft, s_keep + DOFFS, n_past_dft, -1); + llama_kv_cache_seq_rm(ctx_tgt, s_keep, n_past_tgt, -1); + if (s_keep != 0) { + llama_kv_cache_seq_cp(ctx_dft, s_keep + DOFFS, 0 + DOFFS, -1, -1); + llama_kv_cache_seq_cp(ctx_tgt, s_keep, 0, -1, -1); + } + for (int s = 1; s < n_seq_dft; ++s) { + llama_kv_cache_seq_rm(ctx_dft, s + DOFFS, -1, -1); + llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1); + } + + /* llama_kv_cache_seq_keep(ctx_dft, s_keep); llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); llama_kv_cache_seq_keep(ctx_dft, 0); @@ -212,22 +310,28 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_keep(ctx_tgt, s_keep); llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); llama_kv_cache_seq_keep(ctx_tgt, 0); + */ + } for (int s = 0; s < n_seq_dft; ++s) { drafts[s].active = false; drafts[s].tokens.clear(); + drafts[s].tokens_p.clear(); drafts[s].i_batch_tgt.clear(); } // note: will be erased after the speculation phase drafts[0].tokens.push_back(id); + drafts[0].tokens_p.push_back(0); drafts[0].i_batch_tgt.push_back(0); llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); + llama_batch_add (batch_dft, id, n_past_dft, { 0 + DOFFS }, true); + + LOG("=== EVAL: DRAFT ACCEPTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str()); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); llama_decode (ctx_dft, batch_dft); + save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens); ++n_past_dft; @@ -254,6 +358,10 @@ int main(int argc, char ** argv) { llama_batch_clear(batch_tgt); llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + // double avg_accepted = n_accept > 0 ? avg_accepted / double(n_accept) : 0; + LOG("Average accepted/rejected: %3.5f / %3.5f -- Min accepted/max rejected: %3.5f / %3.5f\n", + avg_accepted, avg_rejected, min_accepted, max_rejected); + // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { batch_dft.n_tokens = 0; @@ -267,17 +375,24 @@ int main(int argc, char ** argv) { continue; } + restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft); llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); + save_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft); const auto & cur_p = drafts[s].ctx_sampling->cur; for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) { + if (cur_p[k].p < 1e-5f) continue; LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); } - if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); + double accept_threshold = avg_rejected == 0 || avg_rejected == 0 || n_drafted < 16 + ? p_accept + : std::max(double(min_accepted * 0.98), avg_accepted * 0.75f); + // accept_threshold = 0.8; + if (cur_p[0].p < accept_threshold) { + LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, accept_threshold); drafts[s].drafting = false; continue; } @@ -286,11 +401,20 @@ int main(int argc, char ** argv) { // attempt to split the branch if the probability is high enough for (int f = 1; f < 8; ++f) { - if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { - LOG("splitting seq %3d into %3d\n", s, n_seq_cur); + // if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { + // if (n_seq_cur < n_seq_dft && cur_p[f].p > cur_p[0].p / 5) { + double split_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16 + ? p_split + : ( std::max(double(min_accepted * 0.7), avg_accepted * 0.4) + * (n_seq_cur >= 2 ? 0.75 : 1.0) ); + // split_threshold = 0.1; + if (n_seq_cur < n_seq_dft && cur_p[f].p >= split_threshold) { + n_split++; + LOG(">>>%d<<< splitting seq %3d into %3d on %6d (%8.3f) '%s'\n", f, s, n_seq_cur, + cur_p[f].id, cur_p[f].p, llama_token_to_piece(ctx_dft, cur_p[f].id).c_str()); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_kv_cache_seq_rm(ctx_dft, n_seq_cur + DOFFS, -1, -1); + llama_kv_cache_seq_cp(ctx_dft, s + DOFFS, n_seq_cur + DOFFS, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -309,6 +433,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].skip = true; drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].tokens_p = drafts[s].tokens_p; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; @@ -331,6 +456,7 @@ int main(int argc, char ** argv) { llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id); + drafts[s].tokens_p.push_back(cur_p[is].p); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -340,7 +466,7 @@ int main(int argc, char ** argv) { // add the token to the batch for batched decoding with the draft model drafts[s].i_batch_dft = batch_dft.n_tokens; - llama_batch_add(batch_dft, id, n_past_cur, { s }, true); + llama_batch_add(batch_dft, id, n_past_cur, { s + DOFFS }, true); if (batch_tgt.n_tokens > n_draft) { drafts[s].drafting = false; @@ -352,9 +478,18 @@ int main(int argc, char ** argv) { if (batch_dft.n_tokens == 0) { break; } + // LOG("Draft eval: %d\n", batch_dft.n_tokens); + // for (int x = 0; x < batch_dft.n_tokens; x++) { + // LOG("* %03d: seq %3d, pos %4d, token %6d '%s'", x, + // batch_dft.seq_id[x][0], batch_dft.pos[x], + // batch_dft.token[x], llama_token_to_piece(ctx_dft, batch_dft.token[x]).c_str()); + // } + + LOG("=== EVAL: DRAFTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str()); // evaluate the drafted tokens on the draft model llama_decode(ctx_dft, batch_dft); + save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens); ++n_past_cur; ++n_drafted; @@ -365,13 +500,17 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_cache_seq_keep(ctx_tgt, 0); + // llama_kv_cache_seq_keep(ctx_tgt, 0); + for (int s = 1; s < n_seq_dft; ++s) { + llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1); + } for (int s = 1; s < n_seq_dft; ++s) { llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); } - //LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt)); + LOG("=== EVAL: TARGET ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); + save_logits(ctx_tgt, logits_tgt, n_vocab, batch_tgt.n_tokens); ++n_past_tgt; } @@ -382,6 +521,7 @@ int main(int argc, char ** argv) { } drafts[s].tokens.erase(drafts[s].tokens.begin()); + drafts[s].tokens_p.erase(drafts[s].tokens_p.begin()); } } @@ -395,9 +535,13 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_predict = %d\n", n_predict); + LOG_TEE("drafted = %.3f%%\n", 100.0f * n_drafted / n_predict); LOG_TEE("n_drafted = %d\n", n_drafted); LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + LOG_TEE("n_draft = %d\n", n_draft); + LOG_TEE("n_split = %d\n", n_split); + LOG_TEE("n_badsplit= %d\n", n_bad_split); LOG_TEE("\ndraft:\n"); llama_print_timings(ctx_dft); @@ -415,8 +559,10 @@ int main(int argc, char ** argv) { llama_free(ctx_tgt); llama_free_model(model_tgt); - llama_free(ctx_dft); - llama_free_model(model_dft); + if (!self_speculation) { + llama_free(ctx_dft); + llama_free_model(model_dft); + } llama_backend_free();