What if we do something crazy like add layers instead of removing them?
commit 0abf0064ca (parent d6f35c7ca5)
2 changed files with 218 additions and 48 deletions
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -323,10 +323,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
 
-    const int32_t n_layers = 32; // model layer count
-    const int test_count = 6; // num perplexity chunks to run for each test
-    const size_t prune_target = 4; // prune this many of the worst results each pass
-    // end tunables
+    // model layer count
+    const int32_t n_layers = 32;
+
+    // num perplexity chunks to run for each test
+    const int test_count = 4;
+
+    // prune this many of the worst results each pass
+    const size_t prune_target = 2;
+
+    // start with all but first/last layers disabled and start adding them back
+    const bool anti_mode = true;
+
+    // **** end tunables ***
 
     // 1 = attn, 2 = mlp, 3 = both
     int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
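Note on the tunables above: the per-layer skip codes pack two flags into the low bits (1 = attn, 2 = mlp, 3 = both, per the comment in the hunk), and anti_mode reuses the same bookkeeping but XORs the final per-layer value with a mask of 3, turning "skip" into "keep" so the search starts from everything disabled and adds blocks back. A minimal standalone sketch of that decoding (names here are illustrative, not from the diff):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t SKIP_ATTN = 1, SKIP_MLP = 2; // bit flags, as in the diff
        int32_t code = SKIP_ATTN;                  // anti_mode: attention was "added back"
        const int32_t mask = 3;                    // anti_mode ? 3 : 0
        int32_t disabled = mask ^ code;            // 3 ^ 1 == 2: only the MLP stays off
        printf("attn off: %d, mlp off: %d\n",
               (disabled & SKIP_ATTN) != 0, (disabled & SKIP_MLP) != 0);
        return 0;
    }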
@@ -340,11 +349,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     skip_types.resize(n_layers);
     std::fill(skip_types.begin(), skip_types.end(), 0);
     std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
-    std::vector<int32_t> worsts;
-    worsts.resize(n_layers);
-    std::fill(worsts.begin(), worsts.end(), 0);
+    std::vector<int32_t> extremes;
+    extremes.resize(n_layers);
+    std::fill(extremes.begin(), extremes.end(), 0);
+    if (anti_mode) {
+        // No point in starting with first/last layer disabled.
+        skip_types[0] = 15;
+        skip_types[n_layers - 1] = 15;
+        skips.push_back(0); skips.push_back(0 + n_layers);
+        skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
+    }
     int32_t curr_best_layer = -1, curr_best_type = 0;
     double curr_best_ppl = -1, ref_ppl = -1;
+    const int32_t mask = anti_mode ? 3 : 0;
 
     int count = 0;
     double nll = 0.0;
@@ -372,35 +389,40 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         }
         if (skip_layer >= n_layers) {
             if (curr_best_layer == -1) break;
-            if (pass_results.size() >= prune_target * 2) {
+            if (prune_target > 0 && pass_results.size() >= prune_target * 2) {
                 std::sort(pass_results.begin(), pass_results.end(),
                     [](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
+                        if (anti_mode) return std::get<2>(b) > std::get<2>(a);
                         return std::get<2>(a) > std::get<2>(b);
                     }
                 );
                 const size_t num_prune = std::min(pass_results.size(), prune_target);
-                for (size_t temp = 0; temp < num_prune; temp++) {
+                for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
                     int32_t lidx = std::get<0>(pass_results[temp]);
                     if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
-                    worsts[lidx] |= std::get<1>(pass_results[temp]);
-                    printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
+                    extremes[lidx] |= std::get<1>(pass_results[temp]);
+                    printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
+                        std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
+                    if (anti_mode) {
+                        skip_types[lidx] |= std::get<1>(pass_results[temp]);
+                        skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : -lidx);
+                    }
+                    if (++pruned >= num_prune) break;
                 }
             }
             pass_results.clear();
-            printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f",
+            printf("\n\nADD %c%3d - ppl vs ref %.4f",
                 int(label[curr_best_type]), curr_best_layer,
                 curr_best_ppl - ref_ppl);
-            if (curr_best_ppl > ref_ppl * 1.75) break;
+            if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break;
             skip_types[curr_best_layer] += curr_best_type;
-            if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) {
-                skips.push_back(curr_best_layer);
-            }
+            skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
             curr_best_layer = -1;
             curr_best_ppl = -1;
             curr_best_type = 0;
             skip_layer = n_layers;
             for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
-                skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2);
+                skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
             }
             for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
                 int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
@@ -420,16 +442,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         logit_history.clear();
         prob_history.clear();
 
+        int alive = 0;
         for (int32_t i = 0; i < n_layers; i++) {
-            layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0);
+            layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0));
+            alive += !(layers[i] & 1) + !(layers[i] & 2);
         }
         layers[n_layers] = -1;
         printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
-        for (const auto l : skips) {
-            printf("%c%d, ", int(label[skip_types[l] & 3]), l);
+        for (auto l : skips) {
+            printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers);
         }
-        printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n",
-            skips.size() + 1,
+        printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n",
+            alive, n_layers * 2,
             int(label[curr_best_type]), curr_best_layer,
             curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
             test_t_total);
@@ -477,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         const auto t_end = std::chrono::high_resolution_clock::now();
 
-        if (i == 0 && skip_layer < 0 && skips.empty()) {
+        if (i == 0 && skip_layer < 0 && ref_ppl < 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total * n_chunk);
@@ -516,7 +540,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
-        if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) {
+        if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) {
            i = test_count - 1;
            skip_types[skip_layer] |= test_skip_type << 2;
            if (curr_best_layer == -1 || ppl < curr_best_ppl) {
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -8,6 +8,8 @@
 #include <string>
 #include <vector>
 
+#define DOFFS 10000
+
 struct seq_draft {
     bool active = false;
     bool drafting = false;
@@ -17,10 +19,31 @@ struct seq_draft {
     std::vector<int> i_batch_tgt;
 
     std::vector<llama_token> tokens;
+    std::vector<float> tokens_p;
 
     struct llama_sampling_context * ctx_sampling;
 };
 
+static void save_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
+    // printf("SAVE %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
+    // printf("<S>");
+    GGML_ASSERT(doffs + count <= 30);
+    memcpy(
+        v.data() + doffs * n_vocab,
+        llama_get_logits(ctx) + soffs * n_vocab,
+        sizeof(float) * size_t(n_vocab) * count);
+}
+
+static void restore_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
+    // printf("<R>");
+    // printf("REST %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
+    GGML_ASSERT(soffs + count <= 30);
+    memcpy(
+        llama_get_logits(ctx) + doffs * n_vocab,
+        v.data() + soffs * n_vocab,
+        sizeof(float) * size_t(n_vocab) * count);
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
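The motivation for these two helpers appears to be that self-speculation (introduced below) runs the draft and target passes through the same llama_context, and each llama_decode overwrites the logits of the previous call. save_logits stashes count rows of n_vocab floats in a side buffer and restore_logits copies them back just before sampling; the hard-coded 30 matches the n_vocab * 30 sizing of logits_tgt/logits_dft later in the file. A standalone sketch of the row-stash pattern (plain vectors standing in for llama_get_logits):

    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Copy `count` rows of `n_vocab` floats between a live buffer and a stash.
    static void copy_rows(float * dst, const float * src, int n_vocab, int count) {
        memcpy(dst, src, sizeof(float) * size_t(n_vocab) * count);
    }

    int main() {
        const int n_vocab = 4;
        std::vector<float> live(n_vocab, 0.5f);  // stands in for llama_get_logits(ctx)
        std::vector<float> stash(n_vocab * 30);  // side buffer, 30 rows as in the diff

        copy_rows(stash.data(), live.data(), n_vocab, 1); // save row 0 after a decode
        live[0] = -1.0f;                                  // a later decode clobbers it
        copy_rows(live.data(), stash.data(), n_vocab, 1); // restore before sampling
        printf("restored: %.1f\n", live[0]);              // prints 0.5
        return 0;
    }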
@@ -37,8 +60,10 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // TODO: make this configurable
-    const float p_accept = 0.80f;
-    const float p_split = 0.10f;
+    // const float p_accept = 0.80f;
+    // const float p_split = 0.10f;
+    const float p_accept = 0.5f; // 0.80f;
+    const float p_split = p_accept / 8; // 0.10f;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
@@ -46,6 +71,8 @@ int main(int argc, char ** argv) {
     log_dump_cmdline(argc, argv);
 #endif // LOG_DISABLE_LOGS
 
+    bool self_speculation = false;
+
     // init llama.cpp
     llama_backend_init(params.numa);
 
@@ -60,9 +87,18 @@ int main(int argc, char ** argv) {
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
 
     // load the draft model
-    params.model = params.model_draft;
-    params.n_gpu_layers = params.n_gpu_layers_draft;
-    std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    if (params.model != params.model_draft) {
+        params.model = params.model_draft;
+        params.n_gpu_layers = params.n_gpu_layers_draft;
+        std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+    } else {
+        self_speculation = true;
+        model_dft = model_tgt;
+        ctx_dft = ctx_tgt;
+    }
+
+    const int n_ctx = llama_n_ctx(ctx_tgt);
+    const int n_vocab = llama_n_vocab(model_tgt);
 
     // tokenize the prompt
     std::vector<llama_token> inp;
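The draft-model load is now conditional: when --model-draft resolves to the same path as the main model, the example flips self_speculation on and aliases the draft model/context to the target ones instead of loading the weights twice. A condensed sketch of just that control flow (paths and stand-in handles are illustrative):

    #include <cstdio>
    #include <string>

    int main() {
        std::string model = "model.gguf", model_draft = "model.gguf"; // hypothetical paths
        int tgt = 1, dft = 0;            // stand-ins for the (model, context) pairs
        bool self_speculation = false;

        if (model != model_draft) {
            dft = 2;                     // the real code loads a second model here
        } else {
            self_speculation = true;
            dft = tgt;                   // share one model and one llama_context
        }
        printf("self_speculation = %d, shared = %d\n", self_speculation, dft == tgt);
        return 0;
    }

The matching cleanup change at the end of the file skips llama_free/llama_free_model for the draft in this case, since freeing the shared handles twice would be a double-free.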
@@ -84,14 +120,33 @@ int main(int argc, char ** argv) {
 
     fflush(stderr);
 
+    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
+    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+    std::vector<float> logits_tgt, logits_dft;
+
     const int n_input = inp.size();
 
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
+    llama_batch_clear(batch_tgt);
+    logits_tgt.resize(n_vocab * 30);
+    logits_dft.resize(n_vocab * 30);
+    for (int i = 0; i < n_input - 1; i++) {
+        llama_batch_add(batch_tgt, inp[i], i, { 0 }, false);
+    }
+    llama_decode(ctx_tgt, batch_tgt);
+    llama_batch_clear(batch_tgt);
+    llama_batch_add(batch_tgt, inp.back(), n_input - 1, { 0 }, true);
+    llama_decode(ctx_tgt, batch_tgt);
+    save_logits(ctx_tgt, logits_tgt, n_vocab);
+    if (!self_speculation) {
+        llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
+    } else {
+        // llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
+        llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, -1);
+    }
+    // save_logits(ctx_dft, logits_dft, n_vocab, n_input);
 
     const auto t_enc_end = ggml_time_us();
 
@@ -104,6 +159,8 @@ int main(int argc, char ** argv) {
     int n_predict = 0;
     int n_drafted = 0;
     int n_accept = 0;
+    int n_split = 0;
+    int n_bad_split = 0;
 
     int n_past_tgt = inp.size();
     int n_past_dft = inp.size();
@@ -124,8 +181,16 @@ int main(int argc, char ** argv) {
         drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
     }
 
-    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+    // std::vector<int32_t> run_layers_dft = {
+    //     0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
+    //     3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
+    //     0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
+    //     3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
+    std::vector<int32_t> run_layers_dft = {
+        0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1,
+        1, 0, 1, 0, 0, 0, -1, };
+
+    batch_dft.run_layers = run_layers_dft.data();
 
     const auto t_dec_start = ggml_time_us();
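run_layers_dft above is a per-layer skip list for the draft pass, using the same encoding as the perplexity tool in this commit (bit 0 = skip attention, bit 1 = skip MLP, so 0 runs the whole layer and 3 skips it entirely) with -1 terminating the list; the commented-out 80-entry list and the active 26-layer one look like outputs of that search. Note that batch_dft.run_layers appears to be a field this branch adds to llama_batch; it is not mainline llama.cpp API. A standalone sketch that tallies what such a list disables (the helper is illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Count disabled attention/MLP blocks in a run_layers-style list
    // (0 = run both, 1 = skip attn, 2 = skip mlp, 3 = skip both, -1 ends it).
    static void summarize(const std::vector<int32_t> & run_layers) {
        int attn_off = 0, mlp_off = 0, n = 0;
        for (int32_t code : run_layers) {
            if (code < 0) break;
            attn_off += (code & 1) != 0;
            mlp_off  += (code & 2) != 0;
            n++;
        }
        printf("%d layers: %d attn blocks off, %d mlp blocks off\n", n, attn_off, mlp_off);
    }

    int main() {
        summarize({ 0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1,
                    1, 0, 1, 0, 0, 0, -1 });
        return 0;
    }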
@@ -133,7 +198,11 @@ int main(int argc, char ** argv) {
     drafts[0].i_batch_tgt.resize(1);
     drafts[0].i_batch_tgt[0] = 0;
 
+    double avg_accepted = 0, avg_rejected = 0;
+    float min_accepted = 0, max_rejected = 0;
+
     while (true) {
+        LOG("*** Draft start\n");
         // print current draft sequences
         for (int s = 0; s < n_seq_dft; ++s) {
             if (!drafts[s].active) {
@@ -152,9 +221,11 @@ int main(int argc, char ** argv) {
             LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
 
             // sample from the target model
+            restore_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
             llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
 
             llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
+            save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
 
             //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@@ -179,11 +250,26 @@ int main(int argc, char ** argv) {
                 }
 
                 if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
-                    LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
+                    LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
 
                     s_keep = s;
                     matches = true;
+                    LOG("Derp[%d]: %6d (%5.4f)\n", s, drafts[s].tokens[i_dft], drafts[s].tokens_p[i_dft]);
+                    if (min_accepted == 0) min_accepted = drafts[s].tokens_p[i_dft];
+                    else min_accepted = std::min(min_accepted, drafts[s].tokens_p[i_dft]);
+                    avg_accepted += drafts[s].tokens_p[i_dft] * (avg_accepted == 0 ? 2 : 1);
+                    avg_accepted /= 2;
                 } else {
+                    if (i_dft < (int) drafts[s].tokens.size() && id != drafts[s].tokens[i_dft]) {
+                        if (i_dft == 0 && s > 0) n_bad_split++;
+                        max_rejected = std::max(max_rejected, drafts[s].tokens_p[i_dft]);
+                        avg_rejected += drafts[s].tokens_p[i_dft] * (avg_rejected == 0 ? 2 : 1);
+                        avg_rejected /= 2;
+                        LOG("-- Terminate sequence %d+%d: (%d, '%s') != target (%d, '%s') - rejected\n",
+                            s, i_dft, drafts[s].tokens[i_dft],
+                            llama_token_to_piece(ctx_dft, drafts[s].tokens[i_dft]).c_str(),
+                            id, token_str.c_str());
+                    }
                     drafts[s].active = false;
                 }
             }
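The statistics added in this hunk feed the adaptive draft thresholds later on. The update pattern avg += x * (avg == 0 ? 2 : 1); avg /= 2; seeds the average with the first sample (x * 2 / 2 == x) and afterwards acts as an exponential moving average that gives the newest sample weight 1/2. The same update in isolation (values illustrative):

    #include <cstdio>

    int main() {
        double avg = 0;
        const double samples[] = { 0.9, 0.7, 0.8, 0.2 };
        for (double x : samples) {
            avg += x * (avg == 0 ? 2 : 1); // first sample seeds the average
            avg /= 2;                      // then: avg = (avg + x) / 2
            printf("x = %.2f -> avg = %.4f\n", x, avg);
        }
        return 0;
    }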
@@ -204,6 +290,18 @@ int main(int argc, char ** argv) {
         {
             LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
+            llama_kv_cache_seq_rm(ctx_dft, s_keep + DOFFS, n_past_dft, -1);
+            llama_kv_cache_seq_rm(ctx_tgt, s_keep, n_past_tgt, -1);
+            if (s_keep != 0) {
+                llama_kv_cache_seq_cp(ctx_dft, s_keep + DOFFS, 0 + DOFFS, -1, -1);
+                llama_kv_cache_seq_cp(ctx_tgt, s_keep, 0, -1, -1);
+            }
+            for (int s = 1; s < n_seq_dft; ++s) {
+                llama_kv_cache_seq_rm(ctx_dft, s + DOFFS, -1, -1);
+                llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1);
+            }
+
+            /*
             llama_kv_cache_seq_keep(ctx_dft, s_keep);
             llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
             llama_kv_cache_seq_keep(ctx_dft, 0);
@@ -212,22 +310,28 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_keep(ctx_tgt, s_keep);
             llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
             llama_kv_cache_seq_keep(ctx_tgt, 0);
+            */
 
         }
 
         for (int s = 0; s < n_seq_dft; ++s) {
             drafts[s].active = false;
             drafts[s].tokens.clear();
+            drafts[s].tokens_p.clear();
             drafts[s].i_batch_tgt.clear();
         }
         // note: will be erased after the speculation phase
         drafts[0].tokens.push_back(id);
+        drafts[0].tokens_p.push_back(0);
         drafts[0].i_batch_tgt.push_back(0);
 
         llama_batch_clear(batch_dft);
-        llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
+        llama_batch_add (batch_dft, id, n_past_dft, { 0 + DOFFS }, true);
 
-        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+        LOG("=== EVAL: DRAFT ACCEPTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
+
         llama_decode (ctx_dft, batch_dft);
+        save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
 
         ++n_past_dft;
@@ -254,6 +358,10 @@ int main(int argc, char ** argv) {
         llama_batch_clear(batch_tgt);
         llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
 
+        // double avg_accepted = n_accept > 0 ? avg_accepted / double(n_accept) : 0;
+        LOG("Average accepted/rejected: %3.5f / %3.5f -- Min accepted/max rejected: %3.5f / %3.5f\n",
+            avg_accepted, avg_rejected, min_accepted, max_rejected);
+
         // sample n_draft tokens from the draft model using tree-based sampling
         for (int i = 0; i < n_draft; ++i) {
             batch_dft.n_tokens = 0;
@@ -267,17 +375,24 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
+                restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
                 llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+                save_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
 
                 const auto & cur_p = drafts[s].ctx_sampling->cur;
 
                 for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
+                    if (cur_p[k].p < 1e-5f) continue;
                     LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
                         k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
                 }
 
-                if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
+                double accept_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16
+                    ? p_accept
+                    : std::max(double(min_accepted * 0.98), avg_accepted * 0.75f);
+                // accept_threshold = 0.8;
+                if (cur_p[0].p < accept_threshold) {
+                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, accept_threshold);
                     drafts[s].drafting = false;
                     continue;
                 }
@@ -286,11 +401,20 @@ int main(int argc, char ** argv) {
 
                 // attempt to split the branch if the probability is high enough
                 for (int f = 1; f < 8; ++f) {
-                    if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
-                        LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
+                    // if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
+                    // if (n_seq_cur < n_seq_dft && cur_p[f].p > cur_p[0].p / 5) {
+                    double split_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16
+                        ? p_split
+                        : ( std::max(double(min_accepted * 0.7), avg_accepted * 0.4)
+                            * (n_seq_cur >= 2 ? 0.75 : 1.0) );
+                    // split_threshold = 0.1;
+                    if (n_seq_cur < n_seq_dft && cur_p[f].p >= split_threshold) {
+                        n_split++;
+                        LOG(">>>%d<<< splitting seq %3d into %3d on %6d (%8.3f) '%s'\n", f, s, n_seq_cur,
+                            cur_p[f].id, cur_p[f].p, llama_token_to_piece(ctx_dft, cur_p[f].id).c_str());
 
-                        llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
-                        llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
+                        llama_kv_cache_seq_rm(ctx_dft, n_seq_cur + DOFFS, -1, -1);
+                        llama_kv_cache_seq_cp(ctx_dft, s + DOFFS, n_seq_cur + DOFFS, -1, -1);
 
                         // all previous tokens from this branch are now also part of the new branch
                         for (int t = 0; t < batch_tgt.n_tokens; ++t) {
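Once 16 tokens have been drafted and both running averages are seeded, the two fixed knobs (p_accept, p_split) hand over to cutoffs derived from what the target actually accepted: drafting continues while the top candidate clears max(0.98 * min_accepted, 0.75 * avg_accepted), and a runner-up spawns a new sequence when it clears max(0.7 * min_accepted, 0.4 * avg_accepted), discounted by 0.75 once two or more sequences exist. Condensed into a standalone sketch (constants from the diff, the sample statistics are illustrative):

    #include <algorithm>
    #include <cstdio>

    int main() {
        double avg_accepted = 0.82, min_accepted = 0.55; // illustrative running stats
        int    n_seq_cur    = 2;

        double accept_threshold = std::max(min_accepted * 0.98, avg_accepted * 0.75);
        double split_threshold  = std::max(min_accepted * 0.70, avg_accepted * 0.40)
                                * (n_seq_cur >= 2 ? 0.75 : 1.0);

        printf("keep drafting while top p >= %.3f; split when runner-up p >= %.3f\n",
               accept_threshold, split_threshold);
        return 0;
    }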
@@ -309,6 +433,7 @@ int main(int argc, char ** argv) {
                        drafts[n_seq_cur].skip = true;
 
                        drafts[n_seq_cur].tokens = drafts[s].tokens;
+                       drafts[n_seq_cur].tokens_p = drafts[s].tokens_p;
                        drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                        drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
@@ -331,6 +456,7 @@ int main(int argc, char ** argv) {
                    llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
 
                    drafts[s].tokens.push_back(id);
+                   drafts[s].tokens_p.push_back(cur_p[is].p);
 
                    // add unique drafted tokens to the target batch
                    drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -340,7 +466,7 @@ int main(int argc, char ** argv) {
                    // add the token to the batch for batched decoding with the draft model
                    drafts[s].i_batch_dft = batch_dft.n_tokens;
 
-                   llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+                   llama_batch_add(batch_dft, id, n_past_cur, { s + DOFFS }, true);
 
                    if (batch_tgt.n_tokens > n_draft) {
                        drafts[s].drafting = false;
@@ -352,9 +478,18 @@ int main(int argc, char ** argv) {
            if (batch_dft.n_tokens == 0) {
                break;
            }
+           // LOG("Draft eval: %d\n", batch_dft.n_tokens);
+           // for (int x = 0; x < batch_dft.n_tokens; x++) {
+           //     LOG("* %03d: seq %3d, pos %4d, token %6d '%s'", x,
+           //         batch_dft.seq_id[x][0], batch_dft.pos[x],
+           //         batch_dft.token[x], llama_token_to_piece(ctx_dft, batch_dft.token[x]).c_str());
+           // }
+
+           LOG("=== EVAL: DRAFTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
 
            // evaluate the drafted tokens on the draft model
            llama_decode(ctx_dft, batch_dft);
+           save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
            ++n_past_cur;
            ++n_drafted;
@@ -365,13 +500,17 @@ int main(int argc, char ** argv) {
 
        // evaluate the target model on the drafted tokens
        {
-           llama_kv_cache_seq_keep(ctx_tgt, 0);
+           // llama_kv_cache_seq_keep(ctx_tgt, 0);
+           for (int s = 1; s < n_seq_dft; ++s) {
+               llama_kv_cache_seq_rm(ctx_tgt, s, -1, -1);
+           }
            for (int s = 1; s < n_seq_dft; ++s) {
                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
            }
 
-           //LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
+           LOG("=== EVAL: TARGET ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
            llama_decode(ctx_tgt, batch_tgt);
+           save_logits(ctx_tgt, logits_tgt, n_vocab, batch_tgt.n_tokens);
            ++n_past_tgt;
        }
@@ -382,6 +521,7 @@ int main(int argc, char ** argv) {
            }
 
            drafts[s].tokens.erase(drafts[s].tokens.begin());
+           drafts[s].tokens_p.erase(drafts[s].tokens_p.begin());
        }
    }
@@ -395,9 +535,13 @@ int main(int argc, char ** argv) {
    LOG_TEE("\n");
    LOG_TEE("n_draft = %d\n", n_draft);
    LOG_TEE("n_predict = %d\n", n_predict);
+   LOG_TEE("drafted = %.3f%%\n", 100.0f * n_drafted / n_predict);
    LOG_TEE("n_drafted = %d\n", n_drafted);
    LOG_TEE("n_accept = %d\n", n_accept);
    LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+   LOG_TEE("n_draft = %d\n", n_draft);
+   LOG_TEE("n_split = %d\n", n_split);
+   LOG_TEE("n_badsplit= %d\n", n_bad_split);
 
    LOG_TEE("\ndraft:\n");
    llama_print_timings(ctx_dft);
@@ -415,8 +559,10 @@ int main(int argc, char ** argv) {
    llama_free(ctx_tgt);
    llama_free_model(model_tgt);
 
-   llama_free(ctx_dft);
-   llama_free_model(model_dft);
+   if (!self_speculation) {
+       llama_free(ctx_dft);
+       llama_free_model(model_dft);
+   }
 
    llama_backend_free();