What if we do something crazy like add layers instead of removing them?

This commit is contained in:
KerfuffleV2 2023-10-19 18:00:15 -06:00
parent d6f35c7ca5
commit 0abf0064ca
2 changed files with 218 additions and 48 deletions

View file

@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0); llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
const int32_t n_layers = 32; // model layer count // model layer count
const int test_count = 6; // num perplexity chunks to run for each test const int32_t n_layers = 32;
const size_t prune_target = 4; // prune this many of the worst results each pass
// end tunables // num perplexity chunks to run for each test
const int test_count = 4;
// prune this many of the worst results each pass
const size_t prune_target = 2;
// start with all but first/last layers disabled and start adding them back
const bool anti_mode = true;
// **** end tunables ***
// 1 = attn, 2 = mlp, 3 = both // 1 = attn, 2 = mlp, 3 = both
int32_t test_skip_type = 0; // but don't mess with this, it's set automatically. int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
skip_types.resize(n_layers); skip_types.resize(n_layers);
std::fill(skip_types.begin(), skip_types.end(), 0); std::fill(skip_types.begin(), skip_types.end(), 0);
std::vector<std::tuple<int32_t, int32_t, double>> pass_results; std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
std::vector<int32_t> worsts; std::vector<int32_t> extremes;
worsts.resize(n_layers); extremes.resize(n_layers);
std::fill(worsts.begin(), worsts.end(), 0); std::fill(extremes.begin(), extremes.end(), 0);
if (anti_mode) {
// No pointing in starting with first/last layer disabled.
skip_types[0] = 15;
skip_types[n_layers - 1] = 15;
skips.push_back(0); skips.push_back(0 + n_layers);
skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
}
int32_t curr_best_layer = -1, curr_best_type = 0; int32_t curr_best_layer = -1, curr_best_type = 0;
double curr_best_ppl = -1, ref_ppl = -1; double curr_best_ppl = -1, ref_ppl = -1;
const int32_t mask = anti_mode ? 3 : 0;
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
} }
if (skip_layer >= n_layers) { if (skip_layer >= n_layers) {
if (curr_best_layer == -1) break; if (curr_best_layer == -1) break;
if (pass_results.size() >= prune_target * 2) { if (prune_target > 0 && pass_results.size() >= prune_target * 2) {
std::sort(pass_results.begin(), pass_results.end(), std::sort(pass_results.begin(), pass_results.end(),
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) { [](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
if (anti_mode) return std::get<2>(b) > std::get<2>(a);
return std::get<2>(a) > std::get<2>(b); return std::get<2>(a) > std::get<2>(b);
} }
); );
const size_t num_prune = std::min(pass_results.size(), prune_target); const size_t num_prune = std::min(pass_results.size(), prune_target);
for (size_t temp = 0; temp < num_prune; temp++) { for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
int32_t lidx = std::get<0>(pass_results[temp]); int32_t lidx = std::get<0>(pass_results[temp]);
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue; if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
worsts[lidx] |= std::get<1>(pass_results[temp]); extremes[lidx] |= std::get<1>(pass_results[temp]);
printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp])); printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
if (anti_mode) {
skip_types[lidx] |= std::get<1>(pass_results[temp]);
skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : -lidx);
}
if (++pruned >= num_prune) break;
} }
} }
pass_results.clear(); pass_results.clear();
printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f", printf("\n\nADD %c%3d - ppl vs ref %.4f",
int(label[curr_best_type]), curr_best_layer, int(label[curr_best_type]), curr_best_layer,
curr_best_ppl - ref_ppl); curr_best_ppl - ref_ppl);
if (curr_best_ppl > ref_ppl * 1.75) break; if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break;
skip_types[curr_best_layer] += curr_best_type; skip_types[curr_best_layer] += curr_best_type;
if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) { skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
skips.push_back(curr_best_layer);
}
curr_best_layer = -1; curr_best_layer = -1;
curr_best_ppl = -1; curr_best_ppl = -1;
curr_best_type = 0; curr_best_type = 0;
skip_layer = n_layers; skip_layer = n_layers;
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) { for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2); skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
} }
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) { for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3); int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
logit_history.clear(); logit_history.clear();
prob_history.clear(); prob_history.clear();
int alive = 0;
for (int32_t i = 0; i < n_layers; i++) { for (int32_t i = 0; i < n_layers; i++) {
layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0); layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0));
alive += !(layers[i] & 1) + !(layers[i] & 2);
} }
layers[n_layers] = -1; layers[n_layers] = -1;
printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer); printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
for (const auto l : skips) { for (auto l : skips) {
printf("%c%d, ", int(label[skip_types[l] & 3]), l); printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers);
} }
printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n", printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n",
skips.size() + 1, alive, n_layers * 2,
int(label[curr_best_type]), curr_best_layer, int(label[curr_best_type]), curr_best_layer,
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0, curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
test_t_total); test_t_total);
@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0 && skip_layer < 0 && skips.empty()) { if (i == 0 && skip_layer < 0 && ref_ppl < 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count(); const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk); int total_seconds = (int)(t_total * n_chunk);
@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
} }
fflush(stdout); fflush(stdout);
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) { if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) {
i = test_count - 1; i = test_count - 1;
skip_types[skip_layer] |= test_skip_type << 2; skip_types[skip_layer] |= test_skip_type << 2;
if (curr_best_layer == -1 || ppl < curr_best_ppl) { if (curr_best_layer == -1 || ppl < curr_best_ppl) {

View file

@ -8,6 +8,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#define DOFFS 10000
struct seq_draft { struct seq_draft {
bool active = false; bool active = false;
bool drafting = false; bool drafting = false;
@ -17,10 +19,31 @@ struct seq_draft {
std::vector<int> i_batch_tgt; std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
std::vector<float> tokens_p;
struct llama_sampling_context * ctx_sampling; struct llama_sampling_context * ctx_sampling;
}; };
static void save_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
// printf("SAVE %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
// printf("<S>");
GGML_ASSERT(doffs + count <= 30);
memcpy(
v.data() + doffs * n_vocab,
llama_get_logits(ctx) + soffs * n_vocab,
sizeof(float) * size_t(n_vocab) * count);
}
static void restore_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
// printf("<R>");
// printf("REST %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
GGML_ASSERT(soffs + count <= 30);
memcpy(
llama_get_logits(ctx) + doffs * n_vocab,
v.data() + soffs * n_vocab,
sizeof(float) * size_t(n_vocab) * count);
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
@ -37,8 +60,10 @@ int main(int argc, char ** argv) {
const int n_seq_dft = params.n_parallel; const int n_seq_dft = params.n_parallel;
// TODO: make this configurable // TODO: make this configurable
const float p_accept = 0.80f; // const float p_accept = 0.80f;
const float p_split = 0.10f; // const float p_split = 0.10f;
const float p_accept = 0.5f; // 0.80f;
const float p_split = p_accept / 8; // 0.10f;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log")); log_set_target(log_filename_generator("speculative", "log"));
@ -46,6 +71,8 @@ int main(int argc, char ** argv) {
log_dump_cmdline(argc, argv); log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS #endif // LOG_DISABLE_LOGS
bool self_speculation = false;
// init llama.cpp // init llama.cpp
llama_backend_init(params.numa); llama_backend_init(params.numa);
@ -60,9 +87,18 @@ int main(int argc, char ** argv) {
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model // load the draft model
params.model = params.model_draft; if (params.model != params.model_draft) {
params.n_gpu_layers = params.n_gpu_layers_draft; params.model = params.model_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); params.n_gpu_layers = params.n_gpu_layers_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
} else {
self_speculation = true;
model_dft = model_tgt;
ctx_dft = ctx_tgt;
}
const int n_ctx = llama_n_ctx(ctx_tgt);
const int n_vocab = llama_n_vocab(model_tgt);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
@ -84,14 +120,33 @@ int main(int argc, char ** argv) {
fflush(stderr); fflush(stderr);
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
std::vector<float> logits_tgt, logits_dft;
const int n_input = inp.size(); const int n_input = inp.size();
const auto t_enc_start = ggml_time_us(); const auto t_enc_start = ggml_time_us();
// eval the prompt with both models // eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); llama_batch_clear(batch_tgt);
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); logits_tgt.resize(n_vocab * 30);
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); logits_dft.resize(n_vocab * 30);
for (int i = 0; i < n_input - 1; i++) {
llama_batch_add(batch_tgt, inp[i], i, { 0 }, false);
}
llama_decode(ctx_tgt, batch_tgt);
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, inp.back(), n_input - 1, { 0 }, true);
llama_decode(ctx_tgt, batch_tgt);
save_logits(ctx_tgt, logits_tgt, n_vocab);
if (!self_speculation) {
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
} else {
// llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, -1);
}
// save_logits(ctx_dft, logits_dft, n_vocab, n_input);
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
@ -104,6 +159,8 @@ int main(int argc, char ** argv) {
int n_predict = 0; int n_predict = 0;
int n_drafted = 0; int n_drafted = 0;
int n_accept = 0; int n_accept = 0;
int n_split = 0;
int n_bad_split = 0;
int n_past_tgt = inp.size(); int n_past_tgt = inp.size();
int n_past_dft = inp.size(); int n_past_dft = inp.size();
@ -124,8 +181,16 @@ int main(int argc, char ** argv) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams); drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
} }
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); // std::vector<int32_t> run_layers_dft = {
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft); // 0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
// 3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
// 0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
// 3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
std::vector<int32_t> run_layers_dft = {
0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1,
1, 0, 1, 0, 0, 0, -1, };
batch_dft.run_layers = run_layers_dft.data();
const auto t_dec_start = ggml_time_us(); const auto t_dec_start = ggml_time_us();
@ -133,7 +198,11 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt.resize(1); drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0; drafts[0].i_batch_tgt[0] = 0;
double avg_accepted = 0, avg_rejected = 0;
float min_accepted = 0, max_rejected = 0;
while (true) { while (true) {
LOG("*** Draft start\n");
// print current draft sequences // print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) { if (!drafts[s].active) {
@ -152,9 +221,11 @@ int main(int argc, char ** argv) {
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model // sample from the target model
restore_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@ -179,11 +250,26 @@ int main(int argc, char ** argv) {
} }
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
s_keep = s; s_keep = s;
matches = true; matches = true;
LOG("Derp[%d]: %6d (%5.4f)\n", s, drafts[s].tokens[i_dft], drafts[s].tokens_p[i_dft]);
if (min_accepted == 0) min_accepted = drafts[s].tokens_p[i_dft];
else min_accepted = std::min(min_accepted, drafts[s].tokens_p[i_dft]);
avg_accepted += drafts[s].tokens_p[i_dft] * (avg_accepted == 0 ? 2 : 1);
avg_accepted /= 2;
} else { } else {
if (i_dft < (int) drafts[s].tokens.size() && id != drafts[s].tokens[i_dft]) {
if (i_dft == 0 && s > 0) n_bad_split++;
max_rejected = std::max(max_rejected, drafts[s].tokens_p[i_dft]);
avg_rejected += drafts[s].tokens_p[i_dft] * (avg_rejected == 0 ? 2 : 1);
avg_rejected /= 2;
LOG("-- Terminate sequence %d+%d: (%d, '%s') != target (%d, '%s') - rejected\n",
s, i_dft, drafts[s].tokens[i_dft],
llama_token_to_piece(ctx_dft, drafts[s].tokens[i_dft]).c_str(),
id, token_str.c_str());
}
drafts[s].active = false; drafts[s].active = false;
} }
} }
@ -204,6 +290,18 @@ int main(int argc, char ** argv) {
{ {
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_rm(ctx_dft, s_keep + DOFFS, n_past_dft, -1);
llama_kv_cache_seq_rm(ctx_tgt, s_keep, n_past_tgt, -1);
if (s_keep != 0) {
llama_kv_cache_seq_cp(ctx_dft, s_keep + DOFFS, 0 + DOFFS, -1, -1);
llama_kv_cache_seq_cp(ctx_tgt, s_keep, 0, -1, -1);
}
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_rm(ctx_dft, s + DOFFS, -1, -1);
llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1);
}
/*
llama_kv_cache_seq_keep(ctx_dft, s_keep); llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0); llama_kv_cache_seq_keep(ctx_dft, 0);
@ -212,22 +310,28 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, s_keep); llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_kv_cache_seq_keep(ctx_tgt, 0);
*/
} }
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false; drafts[s].active = false;
drafts[s].tokens.clear(); drafts[s].tokens.clear();
drafts[s].tokens_p.clear();
drafts[s].i_batch_tgt.clear(); drafts[s].i_batch_tgt.clear();
} }
// note: will be erased after the speculation phase // note: will be erased after the speculation phase
drafts[0].tokens.push_back(id); drafts[0].tokens.push_back(id);
drafts[0].tokens_p.push_back(0);
drafts[0].i_batch_tgt.push_back(0); drafts[0].i_batch_tgt.push_back(0);
llama_batch_clear(batch_dft); llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); llama_batch_add (batch_dft, id, n_past_dft, { 0 + DOFFS }, true);
LOG("=== EVAL: DRAFT ACCEPTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode (ctx_dft, batch_dft); llama_decode (ctx_dft, batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
++n_past_dft; ++n_past_dft;
@ -254,6 +358,10 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_tgt); llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
// double avg_accepted = n_accept > 0 ? avg_accepted / double(n_accept) : 0;
LOG("Average accepted/rejected: %3.5f / %3.5f -- Min accepted/max rejected: %3.5f / %3.5f\n",
avg_accepted, avg_rejected, min_accepted, max_rejected);
// sample n_draft tokens from the draft model using tree-based sampling // sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) { for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0; batch_dft.n_tokens = 0;
@ -267,17 +375,24 @@ int main(int argc, char ** argv) {
continue; continue;
} }
restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
const auto & cur_p = drafts[s].ctx_sampling->cur; const auto & cur_p = drafts[s].ctx_sampling->cur;
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) { for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
if (cur_p[k].p < 1e-5f) continue;
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
} }
if (cur_p[0].p < p_accept) { double accept_threshold = avg_rejected == 0 || avg_rejected == 0 || n_drafted < 16
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); ? p_accept
: std::max(double(min_accepted * 0.98), avg_accepted * 0.75f);
// accept_threshold = 0.8;
if (cur_p[0].p < accept_threshold) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, accept_threshold);
drafts[s].drafting = false; drafts[s].drafting = false;
continue; continue;
} }
@ -286,11 +401,20 @@ int main(int argc, char ** argv) {
// attempt to split the branch if the probability is high enough // attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) { for (int f = 1; f < 8; ++f) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { // if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur); // if (n_seq_cur < n_seq_dft && cur_p[f].p > cur_p[0].p / 5) {
double split_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16
? p_split
: ( std::max(double(min_accepted * 0.7), avg_accepted * 0.4)
* (n_seq_cur >= 2 ? 0.75 : 1.0) );
// split_threshold = 0.1;
if (n_seq_cur < n_seq_dft && cur_p[f].p >= split_threshold) {
n_split++;
LOG(">>>%d<<< splitting seq %3d into %3d on %6d (%8.3f) '%s'\n", f, s, n_seq_cur,
cur_p[f].id, cur_p[f].p, llama_token_to_piece(ctx_dft, cur_p[f].id).c_str());
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur + DOFFS, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); llama_kv_cache_seq_cp(ctx_dft, s + DOFFS, n_seq_cur + DOFFS, -1, -1);
// all previous tokens from this branch are now also part of the new branch // all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) { for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@ -309,6 +433,7 @@ int main(int argc, char ** argv) {
drafts[n_seq_cur].skip = true; drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens; drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].tokens_p = drafts[s].tokens_p;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
@ -331,6 +456,7 @@ int main(int argc, char ** argv) {
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id); drafts[s].tokens.push_back(id);
drafts[s].tokens_p.push_back(cur_p[is].p);
// add unique drafted tokens to the target batch // add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@ -340,7 +466,7 @@ int main(int argc, char ** argv) {
// add the token to the batch for batched decoding with the draft model // add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens; drafts[s].i_batch_dft = batch_dft.n_tokens;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true); llama_batch_add(batch_dft, id, n_past_cur, { s + DOFFS }, true);
if (batch_tgt.n_tokens > n_draft) { if (batch_tgt.n_tokens > n_draft) {
drafts[s].drafting = false; drafts[s].drafting = false;
@ -352,9 +478,18 @@ int main(int argc, char ** argv) {
if (batch_dft.n_tokens == 0) { if (batch_dft.n_tokens == 0) {
break; break;
} }
// LOG("Draft eval: %d\n", batch_dft.n_tokens);
// for (int x = 0; x < batch_dft.n_tokens; x++) {
// LOG("* %03d: seq %3d, pos %4d, token %6d '%s'", x,
// batch_dft.seq_id[x][0], batch_dft.pos[x],
// batch_dft.token[x], llama_token_to_piece(ctx_dft, batch_dft.token[x]).c_str());
// }
LOG("=== EVAL: DRAFTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
// evaluate the drafted tokens on the draft model // evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft); llama_decode(ctx_dft, batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
++n_past_cur; ++n_past_cur;
++n_drafted; ++n_drafted;
@ -365,13 +500,17 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
{ {
llama_kv_cache_seq_keep(ctx_tgt, 0); // llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1);
}
for (int s = 1; s < n_seq_dft; ++s) { for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
} }
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt)); LOG("=== EVAL: TARGET ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt); llama_decode(ctx_tgt, batch_tgt);
save_logits(ctx_tgt, logits_tgt, n_vocab, batch_tgt.n_tokens);
++n_past_tgt; ++n_past_tgt;
} }
@ -382,6 +521,7 @@ int main(int argc, char ** argv) {
} }
drafts[s].tokens.erase(drafts[s].tokens.begin()); drafts[s].tokens.erase(drafts[s].tokens.begin());
drafts[s].tokens_p.erase(drafts[s].tokens_p.begin());
} }
} }
@ -395,9 +535,13 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("drafted = %.3f%%\n", 100.0f * n_drafted / n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted); LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_split = %d\n", n_split);
LOG_TEE("n_badsplit= %d\n", n_bad_split);
LOG_TEE("\ndraft:\n"); LOG_TEE("\ndraft:\n");
llama_print_timings(ctx_dft); llama_print_timings(ctx_dft);
@ -415,8 +559,10 @@ int main(int argc, char ** argv) {
llama_free(ctx_tgt); llama_free(ctx_tgt);
llama_free_model(model_tgt); llama_free_model(model_tgt);
llama_free(ctx_dft); if (!self_speculation) {
llama_free_model(model_dft); llama_free(ctx_dft);
llama_free_model(model_dft);
}
llama_backend_free(); llama_backend_free();