What if we do something crazy like add layers instead of removing them?
parent d6f35c7ca5
commit 0abf0064ca
2 changed files with 218 additions and 48 deletions
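
The first changed file flips the earlier layer-pruning search around: per the new tunables, the run starts with all but the first and last layers disabled (`anti_mode`) and each pass greedily re-enables the attention or MLP block whose addition most improves perplexity over a few test chunks, while the worst few candidates are set aside (`prune_target`). A hedged sketch of that greedy search, with illustrative names (`evaluate_ppl`, `Candidate`) and the candidate pruning left out:

    // Sketch of the idea only, not the committed code.
    // Per layer, bit 1 = attention block, bit 2 = MLP block.
    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    struct Candidate { int layer; int part; double ppl; };

    std::vector<int> greedy_add_layers(int n_layers, int passes,
            const std::function<double(const std::vector<int> &)> & evaluate_ppl) {
        std::vector<int> enabled(n_layers, 0);      // per-layer bitmask of active blocks
        enabled.front() = enabled.back() = 3;       // always keep the first and last layer
        for (int pass = 0; pass < passes; ++pass) {
            std::vector<Candidate> results;
            for (int l = 1; l + 1 < n_layers; ++l) {
                for (int part : {1, 2}) {
                    if (enabled[l] & part) continue;            // block already active
                    enabled[l] |= part;                         // trial: add this block
                    results.push_back({l, part, evaluate_ppl(enabled)});
                    enabled[l] &= ~part;                        // undo before the next trial
                }
            }
            if (results.empty()) break;
            const auto best = *std::min_element(results.begin(), results.end(),
                [](const Candidate & a, const Candidate & b) { return a.ppl < b.ppl; });
            enabled[best.layer] |= best.part;                   // keep the single best addition
            printf("ADD %c%3d - ppl %.4f\n", best.part == 1 ? 'a' : 'm', best.layer, best.ppl);
        }
        return enabled;
    }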
@@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
const int32_t n_layers = 32; // model layer count
const int test_count = 6; // num perplexity chunks to run for each test
const size_t prune_target = 4; // prune this many of the worst results each pass
// end tunables
// model layer count
const int32_t n_layers = 32;
// num perplexity chunks to run for each test
const int test_count = 4;
// prune this many of the worst results each pass
const size_t prune_target = 2;
// start with all but first/last layers disabled and start adding them back
const bool anti_mode = true;
// **** end tunables ***
// 1 = attn, 2 = mlp, 3 = both
int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
@@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
skip_types.resize(n_layers);
std::fill(skip_types.begin(), skip_types.end(), 0);
std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
std::vector<int32_t> worsts;
worsts.resize(n_layers);
std::fill(worsts.begin(), worsts.end(), 0);
std::vector<int32_t> extremes;
extremes.resize(n_layers);
std::fill(extremes.begin(), extremes.end(), 0);
if (anti_mode) {
// No pointing in starting with first/last layer disabled.
skip_types[0] = 15;
skip_types[n_layers - 1] = 15;
skips.push_back(0); skips.push_back(0 + n_layers);
skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
}
int32_t curr_best_layer = -1, curr_best_type = 0;
double curr_best_ppl = -1, ref_ppl = -1;
const int32_t mask = anti_mode ? 3 : 0;
int count = 0;
double nll = 0.0;
@@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
if (skip_layer >= n_layers) {
if (curr_best_layer == -1) break;
if (pass_results.size() >= prune_target * 2) {
if (prune_target > 0 && pass_results.size() >= prune_target * 2) {
std::sort(pass_results.begin(), pass_results.end(),
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
if (anti_mode) return std::get<2>(b) > std::get<2>(a);
return std::get<2>(a) > std::get<2>(b);
}
);
const size_t num_prune = std::min(pass_results.size(), prune_target);
for (size_t temp = 0; temp < num_prune; temp++) {
for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
int32_t lidx = std::get<0>(pass_results[temp]);
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
worsts[lidx] |= std::get<1>(pass_results[temp]);
printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
extremes[lidx] |= std::get<1>(pass_results[temp]);
printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
if (anti_mode) {
skip_types[lidx] |= std::get<1>(pass_results[temp]);
skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : -lidx);
}
if (++pruned >= num_prune) break;
}
}
pass_results.clear();
printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f",
printf("\n\nADD %c%3d - ppl vs ref %.4f",
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl - ref_ppl);
if (curr_best_ppl > ref_ppl * 1.75) break;
if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break;
skip_types[curr_best_layer] += curr_best_type;
if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) {
skips.push_back(curr_best_layer);
}
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
curr_best_layer = -1;
curr_best_ppl = -1;
curr_best_type = 0;
skip_layer = n_layers;
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2);
skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
}
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
@@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
logit_history.clear();
prob_history.clear();
int alive = 0;
for (int32_t i = 0; i < n_layers; i++) {
layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0);
layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0));
alive += !(layers[i] & 1) + !(layers[i] & 2);
}
layers[n_layers] = -1;
printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
for (const auto l : skips) {
printf("%c%d, ", int(label[skip_types[l] & 3]), l);
for (auto l : skips) {
printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers);
}
printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n",
skips.size() + 1,
printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n",
alive, n_layers * 2,
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
test_t_total);
@@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0 && skip_layer < 0 && skips.empty()) {
if (i == 0 && skip_layer < 0 && ref_ppl < 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
@@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) {
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) {
i = test_count - 1;
skip_types[skip_layer] |= test_skip_type << 2;
if (curr_best_layer == -1 || ppl < curr_best_ppl) {

@@ -8,6 +8,8 @@
#include <string>
#include <vector>
#define DOFFS 10000
struct seq_draft {
bool active = false;
bool drafting = false;
@@ -17,10 +19,31 @@ struct seq_draft {
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
std::vector<float> tokens_p;
struct llama_sampling_context * ctx_sampling;
};
static void save_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
// printf("SAVE %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
// printf("<S>");
GGML_ASSERT(doffs + count <= 30);
memcpy(
v.data() + doffs * n_vocab,
llama_get_logits(ctx) + soffs * n_vocab,
sizeof(float) * size_t(n_vocab) * count);
}
static void restore_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
// printf("<R>");
// printf("REST %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
GGML_ASSERT(soffs + count <= 30);
memcpy(
llama_get_logits(ctx) + doffs * n_vocab,
v.data() + soffs * n_vocab,
sizeof(float) * size_t(n_vocab) * count);
}
int main(int argc, char ** argv) {
gpt_params params;
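
The new `save_logits`/`restore_logits` helpers exist because later hunks let the target and draft share a single `llama_context`: every `llama_decode` overwrites the context's logits, so rows are copied out into a side buffer (sized for 30 rows here) and copied back just before sampling. A hedged sketch of how that pattern is used, assuming this repo's `llama.h`/`common.h` plus the helpers above:

    // Sketch only: restore a saved logits row, sample from it, then re-save it.
    // Mirrors the pattern used later in this diff; not the committed code itself.
    static llama_token sample_with_saved_logits(
            llama_context * ctx, llama_sampling_context * ctx_sampling,
            std::vector<float> & saved, int n_vocab, int i_batch) {
        restore_logits(ctx, saved, n_vocab, 1, i_batch, i_batch);   // undo the clobber from the other pass
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
        llama_sampling_accept(ctx_sampling, ctx, id, true);
        save_logits(ctx, saved, n_vocab, 1, i_batch, i_batch);      // keep the saved row current
        return id;
    }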
@@ -37,8 +60,10 @@ int main(int argc, char ** argv) {
const int n_seq_dft = params.n_parallel;
// TODO: make this configurable
const float p_accept = 0.80f;
const float p_split = 0.10f;
// const float p_accept = 0.80f;
// const float p_split = 0.10f;
const float p_accept = 0.5f; // 0.80f;
const float p_split = p_accept / 8; // 0.10f;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
@@ -46,6 +71,8 @@ int main(int argc, char ** argv) {
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS
bool self_speculation = false;
// init llama.cpp
llama_backend_init(params.numa);
@@ -60,9 +87,18 @@ int main(int argc, char ** argv) {
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
if (params.model != params.model_draft) {
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
} else {
self_speculation = true;
model_dft = model_tgt;
ctx_dft = ctx_tgt;
}
const int n_ctx = llama_n_ctx(ctx_tgt);
const int n_vocab = llama_n_vocab(model_tgt);
// tokenize the prompt
std::vector<llama_token> inp;
@@ -84,14 +120,33 @@ int main(int argc, char ** argv) {
fflush(stderr);
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
std::vector<float> logits_tgt, logits_dft;
const int n_input = inp.size();
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
llama_batch_clear(batch_tgt);
logits_tgt.resize(n_vocab * 30);
logits_dft.resize(n_vocab * 30);
for (int i = 0; i < n_input - 1; i++) {
llama_batch_add(batch_tgt, inp[i], i, { 0 }, false);
}
llama_decode(ctx_tgt, batch_tgt);
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, inp.back(), n_input - 1, { 0 }, true);
llama_decode(ctx_tgt, batch_tgt);
save_logits(ctx_tgt, logits_tgt, n_vocab);
if (!self_speculation) {
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
} else {
// llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, -1);
}
// save_logits(ctx_dft, logits_dft, n_vocab, n_input);
const auto t_enc_end = ggml_time_us();
@@ -104,6 +159,8 @@ int main(int argc, char ** argv) {
int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;
int n_split = 0;
int n_bad_split = 0;
int n_past_tgt = inp.size();
int n_past_dft = inp.size();
@@ -124,8 +181,16 @@ int main(int argc, char ** argv) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
}
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
// std::vector<int32_t> run_layers_dft = {
// 0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
// 3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
// 0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
// 3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
std::vector<int32_t> run_layers_dft = {
0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1,
1, 0, 1, 0, 0, 0, -1, };
batch_dft.run_layers = run_layers_dft.data();
const auto t_dec_start = ggml_time_us();
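
The hard-coded `run_layers_dft` list is the output of the search in the first file: one entry per draft layer where, following the comment there, 1 skips attention, 2 skips the MLP, 3 skips both, and -1 terminates the list; it is attached through the `run_layers` field this fork adds to `llama_batch`. A hedged sketch of assembling such a list from two skip masks (`skip_attn`/`skip_mlp` are illustrative inputs, not part of the commit):

    // Sketch: build a run_layers list in the encoding this diff uses.
    #include <cstdint>
    #include <vector>

    static std::vector<int32_t> make_run_layers(const std::vector<bool> & skip_attn,
                                                const std::vector<bool> & skip_mlp) {
        std::vector<int32_t> run_layers;
        for (size_t i = 0; i < skip_attn.size(); ++i) {
            run_layers.push_back((skip_attn[i] ? 1 : 0) | (skip_mlp[i] ? 2 : 0));
        }
        run_layers.push_back(-1);    // terminator, as in the lists above
        return run_layers;
    }

    // usage, mirroring the diff:
    //   std::vector<int32_t> run_layers_dft = make_run_layers(skip_attn, skip_mlp);
    //   batch_dft.run_layers = run_layers_dft.data();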
@@ -133,7 +198,11 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0;
double avg_accepted = 0, avg_rejected = 0;
float min_accepted = 0, max_rejected = 0;
while (true) {
LOG("*** Draft start\n");
// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
@@ -152,9 +221,11 @@ int main(int argc, char ** argv) {
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model
restore_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@@ -179,11 +250,26 @@ int main(int argc, char ** argv) {
}
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
s_keep = s;
matches = true;
LOG("Derp[%d]: %6d (%5.4f)\n", s, drafts[s].tokens[i_dft], drafts[s].tokens_p[i_dft]);
if (min_accepted == 0) min_accepted = drafts[s].tokens_p[i_dft];
else min_accepted = std::min(min_accepted, drafts[s].tokens_p[i_dft]);
avg_accepted += drafts[s].tokens_p[i_dft] * (avg_accepted == 0 ? 2 : 1);
avg_accepted /= 2;
} else {
if (i_dft < (int) drafts[s].tokens.size() && id != drafts[s].tokens[i_dft]) {
if (i_dft == 0 && s > 0) n_bad_split++;
max_rejected = std::max(max_rejected, drafts[s].tokens_p[i_dft]);
avg_rejected += drafts[s].tokens_p[i_dft] * (avg_rejected == 0 ? 2 : 1);
avg_rejected /= 2;
LOG("-- Terminate sequence %d+%d: (%d, '%s') != target (%d, '%s') - rejected\n",
s, i_dft, drafts[s].tokens[i_dft],
llama_token_to_piece(ctx_dft, drafts[s].tokens[i_dft]).c_str(),
id, token_str.c_str());
}
drafts[s].active = false;
}
}
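
The bookkeeping above uses a compact running average: `avg += p * (avg == 0 ? 2 : 1); avg /= 2;` is an exponential moving average with weight 1/2 whose first sample is doubled so the average starts at that sample. A small self-contained illustration:

    // Sketch of the running-average update used for avg_accepted / avg_rejected.
    #include <cstdio>

    int main() {
        double avg = 0.0;
        const double samples[] = {0.9, 0.7, 0.8};
        for (double p : samples) {
            avg += p * (avg == 0 ? 2 : 1);   // first sample counts double...
            avg /= 2;                        // ...so after halving, avg == p
            printf("avg = %.3f\n", avg);     // prints 0.900, 0.800, 0.800
        }
        return 0;
    }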
@@ -204,6 +290,18 @@ int main(int argc, char ** argv) {
{
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_rm(ctx_dft, s_keep + DOFFS, n_past_dft, -1);
llama_kv_cache_seq_rm(ctx_tgt, s_keep, n_past_tgt, -1);
if (s_keep != 0) {
llama_kv_cache_seq_cp(ctx_dft, s_keep + DOFFS, 0 + DOFFS, -1, -1);
llama_kv_cache_seq_cp(ctx_tgt, s_keep, 0, -1, -1);
}
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_rm(ctx_dft, s + DOFFS, -1, -1);
llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1);
}
/*
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
@@ -212,22 +310,28 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
*/
}
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].tokens_p.clear();
drafts[s].i_batch_tgt.clear();
}
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(id);
drafts[0].tokens_p.push_back(0);
drafts[0].i_batch_tgt.push_back(0);
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
llama_batch_add (batch_dft, id, n_past_dft, { 0 + DOFFS }, true);
LOG("=== EVAL: DRAFT ACCEPTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode (ctx_dft, batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
++n_past_dft;
@@ -254,6 +358,10 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
// double avg_accepted = n_accept > 0 ? avg_accepted / double(n_accept) : 0;
LOG("Average accepted/rejected: %3.5f / %3.5f -- Min accepted/max rejected: %3.5f / %3.5f\n",
avg_accepted, avg_rejected, min_accepted, max_rejected);
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0;
@@ -267,17 +375,24 @@ int main(int argc, char ** argv) {
continue;
}
restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
const auto & cur_p = drafts[s].ctx_sampling->cur;
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
if (cur_p[k].p < 1e-5f) continue;
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}
if (cur_p[0].p < p_accept) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
double accept_threshold = avg_rejected == 0 || avg_rejected == 0 || n_drafted < 16
? p_accept
: std::max(double(min_accepted * 0.98), avg_accepted * 0.75f);
// accept_threshold = 0.8;
if (cur_p[0].p < accept_threshold) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, accept_threshold);
drafts[s].drafting = false;
continue;
}
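
Instead of the fixed `p_accept`, drafting now stops once the top candidate drops below a threshold derived from the running statistics: after at least 16 drafted tokens, the cutoff becomes the larger of 98% of the smallest accepted probability and 75% of the average accepted probability (the guard tests `avg_rejected == 0` twice and presumably meant to check both averages). The split threshold in the next hunk follows the same pattern with smaller factors. A hedged, self-contained sketch of the rule:

    // Sketch of the adaptive drafting cutoff (illustrative, not the committed code).
    #include <algorithm>
    #include <cstdio>

    static double accept_threshold(double p_accept, double avg_accepted,
                                   double avg_rejected, double min_accepted, int n_drafted) {
        // fall back to the static threshold until there is enough signal
        if (avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16) {
            return p_accept;
        }
        return std::max(min_accepted * 0.98, avg_accepted * 0.75);
    }

    int main() {
        printf("cold start: %.3f\n", accept_threshold(0.5, 0.0, 0.0, 0.0, 4));     // 0.500
        printf("warmed up : %.3f\n", accept_threshold(0.5, 0.8, 0.3, 0.55, 64));   // 0.600
        return 0;
    }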
@@ -286,11 +401,20 @@ int main(int argc, char ** argv) {
// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
// if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
// if (n_seq_cur < n_seq_dft && cur_p[f].p > cur_p[0].p / 5) {
double split_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16
? p_split
: ( std::max(double(min_accepted * 0.7), avg_accepted * 0.4)
* (n_seq_cur >= 2 ? 0.75 : 1.0) );
// split_threshold = 0.1;
if (n_seq_cur < n_seq_dft && cur_p[f].p >= split_threshold) {
n_split++;
LOG(">>>%d<<< splitting seq %3d into %3d on %6d (%8.3f) '%s'\n", f, s, n_seq_cur,
cur_p[f].id, cur_p[f].p, llama_token_to_piece(ctx_dft, cur_p[f].id).c_str());
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur + DOFFS, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s + DOFFS, n_seq_cur + DOFFS, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@@ -309,6 +433,7 @@ int main(int argc, char ** argv) {
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].tokens_p = drafts[s].tokens_p;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
@@ -331,6 +456,7 @@ int main(int argc, char ** argv) {
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id);
drafts[s].tokens_p.push_back(cur_p[is].p);
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -340,7 +466,7 @@ int main(int argc, char ** argv) {
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
llama_batch_add(batch_dft, id, n_past_cur, { s + DOFFS }, true);
if (batch_tgt.n_tokens > n_draft) {
drafts[s].drafting = false;
@@ -352,9 +478,18 @@ int main(int argc, char ** argv) {
if (batch_dft.n_tokens == 0) {
break;
}
// LOG("Draft eval: %d\n", batch_dft.n_tokens);
// for (int x = 0; x < batch_dft.n_tokens; x++) {
// LOG("* %03d: seq %3d, pos %4d, token %6d '%s'", x,
// batch_dft.seq_id[x][0], batch_dft.pos[x],
// batch_dft.token[x], llama_token_to_piece(ctx_dft, batch_dft.token[x]).c_str());
// }
LOG("=== EVAL: DRAFTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
++n_past_cur;
++n_drafted;
@@ -365,13 +500,17 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
// llama_kv_cache_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1);
}
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
}
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
LOG("=== EVAL: TARGET ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
save_logits(ctx_tgt, logits_tgt, n_vocab, batch_tgt.n_tokens);
++n_past_tgt;
}
@@ -382,6 +521,7 @@ int main(int argc, char ** argv) {
}
drafts[s].tokens.erase(drafts[s].tokens.begin());
drafts[s].tokens_p.erase(drafts[s].tokens_p.begin());
}
}
@@ -395,9 +535,13 @@ int main(int argc, char ** argv) {
LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("drafted = %.3f%%\n", 100.0f * n_drafted / n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_split = %d\n", n_split);
LOG_TEE("n_badsplit= %d\n", n_bad_split);
LOG_TEE("\ndraft:\n");
llama_print_timings(ctx_dft);
@@ -415,8 +559,10 @@ int main(int argc, char ** argv) {
llama_free(ctx_tgt);
llama_free_model(model_tgt);
llama_free(ctx_dft);
llama_free_model(model_dft);
if (!self_speculation) {
llama_free(ctx_dft);
llama_free_model(model_dft);
}
llama_backend_free();