Sync latest changes

KerfuffleV2 2023-10-23 02:40:37 -06:00
parent 8a569cfee5
commit 13e08d0efa
2 changed files with 258 additions and 54 deletions

[file 1 of 2]

@@ -397,6 +397,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
);
const size_t num_prune = std::min(pass_results.size(), prune_target);
if (num_prune > 0) printf("\nPruning: ");
for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
int32_t lidx = std::get<0>(pass_results[temp]);
if (anti_mode) {
@@ -405,17 +406,17 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
extremes[lidx] |= std::get<1>(pass_results[temp]);
printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
printf("[%zu: %d (%d) - %.2f], ", pruned + 1, lidx,
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
if (++pruned >= num_prune) break;
}
}
pass_results.clear();
printf("\n\nADD %c%3d - ppl vs ref %.4f",
printf("\n\nADD %c%3d - ppl vs ref %.4f - cur:[",
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl - ref_ppl);
if (!anti_mode) {
-if (curr_best_ppl > ref_ppl * 1.75) break;
+// if (curr_best_ppl > ref_ppl * 1.75) break;
skip_types[curr_best_layer] += curr_best_type;
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
}
@@ -426,6 +427,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
}
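// (skip_types packs two 2-bit masks per layer: bits 0-1 hold the
// confirmed skips, bits 2-3 the pruned "extreme" candidates merged
// back in above.)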
for (int32_t i = 0; i < n_layers; i++) {
const int val = mask ^ (skip_types[i] & 3);
printf("%d%s", val, i < n_layers - 1 ? ", " : "]");
}
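// (The loop above dumps the per-layer skip mask as a C-style array;
// `mask` presumably flips the bits when anti_mode is active.)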
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
// printf("||%d, %d\n", new_sl, curr_skipped);

[file 2 of 2]

@@ -15,6 +15,8 @@ struct seq_draft {
bool drafting = false;
bool skip = false;
int split_pos = 0;
int i_batch_dft = 0;
std::vector<int> i_batch_tgt;
@@ -27,7 +29,7 @@ struct seq_draft {
static void save_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
// printf("SAVE %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
// printf("<S>");
-GGML_ASSERT(doffs + count <= 30);
+GGML_ASSERT(doffs + count < 64);
memcpy(
v.data() + doffs * n_vocab,
llama_get_logits(ctx) + soffs * n_vocab,
@@ -37,13 +39,47 @@ static void save_logits(llama_context * ctx, std::vector<float> & v, const int n
static void restore_logits(llama_context * ctx, std::vector<float> & v, const int n_vocab, const int count = 1, const int soffs = 0, const int doffs = 0) {
// printf("<R>");
// printf("REST %p: %d, %d, %d\n", (void *)ctx, count, soffs, doffs);
-GGML_ASSERT(soffs + count <= 30);
+GGML_ASSERT(soffs + count < 64);
memcpy(
llama_get_logits(ctx) + doffs * n_vocab,
v.data() + soffs * n_vocab,
sizeof(float) * size_t(n_vocab) * count);
}
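// Note: save_logits/restore_logits treat the vector as 64 slots of
// n_vocab floats indexed by batch position (the buffers are resized
// to n_vocab * 64 below), hence the bounds asserts above.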
static llama_token_data_array normalize_candidates(const float * logits, const int n_vocab, std::vector<llama_token_data> & cur) {
cur.reserve(n_vocab);
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
llama_sample_top_k(NULL, &cur_p, 100, 1);
llama_sample_softmax(NULL, &cur_p);
cur.resize(cur_p.size);
return cur_p;
}
static int32_t find_normalized(const llama_token_data_array & tda, const llama_token id) {
llama_token_data *item = tda.data;
for (int32_t i = 0; i < tda.size; i++, item++)
if (item->id == id) return i;
return -1;
}
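// Usage sketch for the two helpers above (illustrative, not part of
// the commit): softmax the raw logits into a top-100 candidate list,
// then look up the probability assigned to one token id:
//
//   normalized_p = normalize_candidates(llama_get_logits(ctx_dft), n_vocab, normalized_candidates);
//   int32_t pos  = find_normalized(normalized_p, id);
//   float   p    = pos >= 0 ? normalized_candidates[pos].p : 0.0f;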
static double running_average(double & cur, double val, double n = 20) {
if (cur < 1e-5f) {
cur = val;
return cur;
}
// New average = old average * (n-1)/n + new value /n
cur = cur * (n - 1) / n + val / n;
return cur;
}
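// running_average() keeps an exponential moving average over a window
// of n samples; the cur < 1e-5 branch seeds it with the first value.
// Worked example with the default n = 20 and cur = 0.5:
//   running_average(cur, 1.0); // cur = 0.5 * 19/20 + 1.0 / 20 = 0.525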
int main(int argc, char ** argv) {
gpt_params params;
@@ -62,8 +98,8 @@ int main(int argc, char ** argv) {
// TODO: make this configurable
// const float p_accept = 0.80f;
// const float p_split = 0.10f;
-const float p_accept = 0.5f; // 0.80f;
-const float p_split = p_accept / 8; // 0.10f;
+const float p_accept = 0.75f; // 0.80f;
+const float p_split = 0.6f; // p_accept / 8; // 0.10f;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
@@ -130,8 +166,8 @@ int main(int argc, char ** argv) {
// eval the prompt with both models
llama_batch_clear(batch_tgt);
-logits_tgt.resize(n_vocab * 30);
-logits_dft.resize(n_vocab * 30);
+logits_tgt.resize(n_vocab * 64);
+logits_dft.resize(n_vocab * 64);
for (int i = 0; i < n_input - 1; i++) {
llama_batch_add(batch_tgt, inp[i], i, { 0 }, false);
}
@@ -146,7 +182,7 @@ int main(int argc, char ** argv) {
// llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0 + DOFFS));
llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, -1);
}
-// save_logits(ctx_dft, logits_dft, n_vocab, n_input);
+save_logits(ctx_dft, logits_dft, n_vocab, n_input);
const auto t_enc_end = ggml_time_us();
@@ -161,6 +197,11 @@ int main(int argc, char ** argv) {
int n_accept = 0;
int n_split = 0;
int n_bad_split = 0;
int n_dup_split = 0;
int n_eff_split = 0;
int max_streak = 0;
int64_t t_dft_sample = 0, t_dft_gen = 0, t_dft_accept = 0, t_tgt_predict = 0;
int n_past_tgt = inp.size();
int n_past_dft = inp.size();
@@ -170,26 +211,35 @@ int main(int argc, char ** argv) {
// target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
struct llama_sampling_context * ctx_dft_sampling = llama_sampling_init(params.sparams);
std::vector<llama_token_data> normalized_candidates;
normalized_candidates.reserve(n_vocab);
llama_token_data_array normalized_p;
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-params.sparams.temp = std::max(0.01f, params.sparams.temp);
+// params.sparams.temp = std::max(0.01f, params.sparams.temp);
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
}
// std::vector<int32_t> run_layers_dft = {
// 0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
// 3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
// 0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
// 3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
// 70B (80 layers) skips example
-std::vector<int32_t> run_layers_dft = {
-0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 1, 0, 1, 1, 0, 2, 0, 1, 1,
-1, 0, 1, 0, 0, 0, -1, };
+std::vector<int32_t> run_layers_dft = {
+0, 0, 2, 0, 2, 0, 0, 0, 0, 2, 3, 1, 0, 3, 3, 0, 3, 0, 1, 1,
+3, 3, 3, 0, 2, 3, 2, 3, 3, 3, 1, 3, 0, 0, 2, 1, 0, 2, 0, 0,
+0, 3, 0, 1, 0, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 1, 3, 3, 0,
+3, 1, 3, 3, 0, 1, 3, 3, 3, 1, 3, 0, 0, 0, 1, 1, 2, 0, 1, 1, -1, };
// 3B (26 layers) skips example
// std::vector<int32_t> run_layers_dft = {
// 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 3, 0, 2, 3, 3, 1, 0, 2, 0, 1, 1, 2, 0, 0,
// // 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 2, 1, 3, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 2, 0, 1,
// -1, };
// NOTE: Comment this line out to disable skipping.
batch_dft.run_layers = run_layers_dft.data();
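// (Assumed encoding, mirroring skip_types in the perplexity tool above:
// each entry is a 2-bit skip mask for one draft layer -- 0 runs the
// layer, 1 and 2 skip one of attention/MLP, 3 skips both -- and -1
// terminates the list.)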
const auto t_dec_start = ggml_time_us();
@@ -198,8 +248,13 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0;
-double avg_accepted = 0, avg_rejected = 0;
-float min_accepted = 0, max_rejected = 0;
+double avg_accepted = 0, avg_rejected = 0, tgt_avg_accepted = 0;
+double avg_accept_delta = 0;
+float min_accepted = 0, max_rejected = 0, tgt_min_accepted = 0;
+int64_t t_cur;
+std::vector<std::vector<std::vector<llama_token_data>>> doubt;
while (true) {
LOG("*** Draft start\n");
@@ -217,15 +272,37 @@ int main(int argc, char ** argv) {
int i_dft = 0;
int s_keep = 0;
float tgt_last_norm = 0, tgt_last_best_norm = 0, tgt_last_orig = 0;
while (true) {
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model
restore_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
normalized_p = normalize_candidates(llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]), n_vocab, normalized_candidates);
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
int32_t norm_pos = find_normalized(normalized_p, id);
int32_t orig_pos = find_normalized({ctx_sampling->cur.data(), ctx_sampling->cur.size(), false}, id);
if (norm_pos >= 0) {
tgt_last_norm = normalized_candidates[norm_pos].p;
tgt_last_best_norm = normalized_candidates[0].p;
running_average(tgt_avg_accepted, tgt_last_norm);
tgt_min_accepted = tgt_min_accepted < 1e-4
? tgt_last_norm
: std::min(tgt_min_accepted, tgt_last_norm);
} else {
tgt_last_norm = tgt_last_best_norm = tgt_avg_accepted;
}
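// (If the sampled token fell outside the normalized top-100 list, fall
// back to the running average as a neutral probability estimate.)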
if (orig_pos >= 0) {
tgt_last_orig = ctx_sampling->cur[orig_pos].p;
}
LOG("target sampled (%d, '%s') orig_p=%5.4f, norm_p=%5.4f\n",
id, llama_token_to_piece(ctx_tgt, id).c_str(),
orig_pos >= 0 ? ctx_sampling->cur[orig_pos].p : -1,
norm_pos >= 0 ? normalized_candidates[norm_pos].p : -1);
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
save_logits(ctx_tgt, logits_tgt, n_vocab, 1, drafts[s_keep].i_batch_tgt[i_dft], drafts[s_keep].i_batch_tgt[i_dft]);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@@ -245,26 +322,30 @@ int main(int argc, char ** argv) {
bool matches = false;
for (int s = 0; s < n_seq_dft; ++s) {
-if (!drafts[s].active) {
+if (!drafts[s].active || i_dft < drafts[s].split_pos) {
continue;
}
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
LOG("the sampled target token matches drafted token %d of sequence %d (%d, '%s') - accepted\n",
i_dft, s, id, token_str.c_str());
if (i_dft == 0 && s > 0) {
if (matches) n_dup_split++;
else n_eff_split++;
}
s_keep = s;
matches = true;
LOG("Derp[%d]: %6d (%5.4f)\n", s, drafts[s].tokens[i_dft], drafts[s].tokens_p[i_dft]);
if (min_accepted == 0) min_accepted = drafts[s].tokens_p[i_dft];
else min_accepted = std::min(min_accepted, drafts[s].tokens_p[i_dft]);
-avg_accepted += drafts[s].tokens_p[i_dft] * (avg_accepted == 0 ? 2 : 1);
-avg_accepted /= 2;
+running_average(avg_accepted, drafts[s].tokens_p[i_dft]);
+running_average(avg_accept_delta, tgt_last_norm - drafts[s].tokens_p[i_dft]);
} else {
if (i_dft < (int) drafts[s].tokens.size() && id != drafts[s].tokens[i_dft]) {
if (i_dft == 0 && s > 0) n_bad_split++;
max_rejected = std::max(max_rejected, drafts[s].tokens_p[i_dft]);
-avg_rejected += drafts[s].tokens_p[i_dft] * (avg_rejected == 0 ? 2 : 1);
-avg_rejected /= 2;
+running_average(avg_rejected, drafts[s].tokens_p[i_dft]);
LOG("-- Terminate sequence %d+%d: (%d, '%s') != target (%d, '%s') - rejected\n",
s, i_dft, drafts[s].tokens[i_dft],
llama_token_to_piece(ctx_dft, drafts[s].tokens[i_dft]).c_str(),
@@ -279,8 +360,27 @@ int main(int argc, char ** argv) {
++n_past_tgt;
++n_past_dft;
++i_dft;
max_streak = std::max(max_streak, i_dft);
continue;
} else {
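// Rejection path: scan the saved "doubt" candidate lists to see
// whether the target's token was among the draft's top choices, and
// if so nudge the acceptance statistics toward it.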
for (size_t seqnum = 0; seqnum < doubt.size(); seqnum++) {
const std::vector<std::vector<llama_token_data>> & sdoubt = doubt[seqnum];
if (sdoubt.size() <= i_dft) continue;
const std::vector<llama_token_data> & sidoubt = sdoubt[i_dft];
for (size_t cidx = 0; cidx < sidoubt.size(); cidx++) {
if (sidoubt[cidx].id == id) {
LOG("Shoulda picked seq %3zu, pos %4d, candidate %2zu @ p %5.4f: %6d '%s'\n",
seqnum, i_dft, cidx, sidoubt[cidx].p,
id, token_str.c_str());
running_average(avg_accepted, sidoubt[cidx].p);
if (cidx < 2) {
running_average(avg_accept_delta, tgt_last_norm - sidoubt[cidx].p);
min_accepted = min_accepted < 1e-5f ? sidoubt[cidx].p : std::min(min_accepted, sidoubt[cidx].p);
}
break;
}
}
}
}
}
@@ -315,6 +415,7 @@ int main(int argc, char ** argv) {
}
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].split_pos = 0;
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].tokens_p.clear();
@@ -327,10 +428,18 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 + DOFFS }, true);
if (self_speculation) {
// Copy KV items from non-brain-damaged model... Doesn't seem to help.
llama_kv_cache_seq_rm(ctx_dft, 0 + DOFFS, 0, n_past_dft - 2);
llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, 0, n_past_dft - 2);
// llama_kv_cache_seq_rm(ctx_dft, 0 + DOFFS, n_past_dft - 1, -1);
// llama_kv_cache_seq_cp(ctx_dft, 0, 0 + DOFFS, n_past_dft - 1, -1);
}
LOG("=== EVAL: DRAFT ACCEPTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
t_cur = ggml_time_us();
llama_decode (ctx_dft, batch_dft);
t_dft_accept += ggml_time_us() - t_cur;
save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
++n_past_dft;
@@ -358,9 +467,14 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
avg_rejected = std::max(0.05, std::min(avg_accepted - 0.05, avg_rejected));
avg_accepted = std::max(0.05, std::max(avg_rejected + 0.05, avg_accepted));
// double avg_accepted = n_accept > 0 ? avg_accepted / double(n_accept) : 0;
LOG("Average accepted/rejected: %3.5f / %3.5f -- Min accepted/max rejected: %3.5f / %3.5f\n",
avg_accepted, avg_rejected, min_accepted, max_rejected);
LOG("STATS: Avg tacc/dacc/drej: %3.5f / %3.5f / %3.5f | Min dacc/min tacc/max drej: %3.5f / %3.5f / %3.5f | delta %3.5f | max streak %d | n_dft/pred/acc: %d / %d / %d\n",
tgt_avg_accepted, avg_accepted, avg_rejected, min_accepted, tgt_min_accepted, max_rejected, avg_accept_delta, max_streak,
n_drafted, n_predict, n_accept);
doubt.clear();
doubt.resize(n_seq_dft);
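// (doubt[s][i] holds the top normalized candidates the draft model saw
// for sequence s at draft position i; consulted on rejection above.)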
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
@@ -371,43 +485,116 @@ int main(int argc, char ** argv) {
}
for (int s = 0; s < n_seq_dft; ++s) {
double accept_threshold, split_threshold;
if (!drafts[s].drafting || drafts[s].skip) {
continue;
}
doubt[s].push_back({});
restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
save_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft, drafts[s].i_batch_dft);
if (avg_accepted == 0 || avg_rejected == 0 || n_drafted + n_predict < 6) {
accept_threshold = std::max(0.6f, tgt_last_norm);
} else {
const auto & cur_p = drafts[s].ctx_sampling->cur;
accept_threshold = (tgt_avg_accepted - avg_accept_delta) * 0.3;
accept_threshold *= std::min(0.8, std::max(0.1, double(tgt_last_norm * 1.0)));
accept_threshold = std::max(double(min_accepted) * 1.1, accept_threshold);
accept_threshold = std::max(std::max(avg_accepted * 0.9, avg_rejected * 1.1), accept_threshold);
accept_threshold += 1.0 - (1.2 * n_accept / n_drafted);
accept_threshold *= (1.3 - (std::min(n_seq_cur + i, 6) * 0.1));
//
// accept_threshold = (tgt_avg_accepted - avg_accept_delta) * 0.3;
// accept_threshold *= std::min(0.8, std::max(0.1, double(tgt_last_norm * 1.0)));
// accept_threshold = std::max(double(min_accepted) * 1.1, accept_threshold);
// accept_threshold = std::max(std::max(avg_accepted * 0.9, avg_rejected * 1.1), accept_threshold);
// accept_threshold += 1.0 - (1.2 * n_accept / n_drafted);
// accept_threshold *= (0.7 + (std::min(n_seq_cur + i, 5) * 0.1));
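// Worked example of the live heuristic above (illustrative numbers):
// tgt_avg_accepted = 0.9, avg_accept_delta = 0.05 -> base 0.255;
// scaled by tgt_last_norm = 0.6 -> 0.153; floored by min_accepted(0.2)
// * 1.1 = 0.22 and by max(avg_accepted * 0.9 = 0.45, avg_rejected *
// 1.1 = 0.33) -> 0.45; plus 1.0 - 1.2 * (n_accept / n_drafted = 0.75)
// -> 0.55; times (1.3 - 0.2) for n_seq_cur + i = 2 -> ~0.61.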
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
if (cur_p[k].p < 1e-5f) continue;
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}
-double accept_threshold = avg_rejected == 0 || avg_rejected == 0 || n_drafted < 16
-? p_accept
-: std::max(double(min_accepted * 0.98), avg_accepted * 0.75f);
-// accept_threshold = 0.8;
-if (cur_p[0].p < accept_threshold) {
-LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, accept_threshold);
std::vector<llama_token_data> cur_p;
{
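// Enumerate up to 9 distinct draft candidates: each pass re-samples
// with already-picked token ids masked to -inf, records the token's
// normalized probability, and keeps candidates that clear the
// p >= 0.2 / p >= 0.3 cutoffs below.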
llama_token d_id;
std::vector<llama_token> already_picked;
float * logits = NULL;
t_cur = ggml_time_us();
for (int cidx = 0; cidx < 9; cidx++) {
llama_sampling_cp(drafts[s].ctx_sampling, ctx_dft_sampling);
restore_logits(ctx_dft, logits_dft, n_vocab, 1, drafts[s].i_batch_dft);
logits = llama_get_logits(ctx_dft);
normalized_p = normalize_candidates(logits, n_vocab, normalized_candidates);
for (size_t x = 0; x < std::min(normalized_p.size, size_t(10)); x++)
doubt[s].back().push_back(normalized_p.data[x]);
for (const auto & tid : already_picked)
logits[tid] = std::numeric_limits<float>::infinity() * -1;
d_id = llama_sampling_sample(ctx_dft_sampling, ctx_dft, NULL);
already_picked.push_back(d_id);
int32_t norm_pos = find_normalized(normalized_p, d_id);
if (norm_pos < 0) continue;
llama_token_data norm = normalized_candidates[norm_pos];
if (norm.p < 0.2) continue;
if (ctx_dft_sampling->params.temp <= 0) {
llama_token_data_array tda = { ctx_dft_sampling->cur.data(), ctx_dft_sampling->cur.size(), false };
llama_sample_top_k(ctx_dft, &tda, 100, 1);
llama_sample_softmax(ctx_dft, &tda);
ctx_dft_sampling->cur.resize(tda.size);
}
llama_token_data found;
found.id = -1;
for (const llama_token_data & td : ctx_dft_sampling->cur) {
if (td.id == d_id) {
found = td;
break;
}
}
GGML_ASSERT(found.id != -1);
LOG(" ** draft candidate %3d for seq %3d, pos %3d: %6d (%4.3f, norm %4.3f) '%s'\n",
cidx, s, i, found.id, found.p, norm_pos >= 0 ? normalized_candidates[norm_pos].p : -1,
llama_token_to_piece(ctx_dft, found.id).c_str());
if (found.p < 0.3) continue;
if (norm.p < 1e-2f) break;
cur_p.push_back(normalized_candidates[norm_pos]);
}
if (cur_p.size() > 1) {
std::sort(cur_p.begin() + 1, cur_p.end(),
[](const llama_token_data & a, const llama_token_data & b) {
return a.p > b.p;
}
);
}
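// (Index 0 is left in place: it is the token the draft actually
// sampled; only the alternatives behind it are ordered by probability
// for the split logic below.)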
}
t_dft_sample += ggml_time_us() - t_cur;
if (cur_p.empty()) {
LOG("stopping drafting for seq %3d, no viable candidates (%5.3f) \n", s, accept_threshold);
drafts[s].drafting = false;
continue;
} else if (cur_p[0].p < accept_threshold && (cur_p[0].p + (cur_p.size() < 2 ? 0 : cur_p[1].p)) < accept_threshold * 1.3) {
LOG("stopping drafting for seq %3d, pos %3d - probability too low: %.3f < %.3f\n", s, i, cur_p[0].p, accept_threshold);
drafts[s].drafting = false;
continue;
}
if (cur_p[0].p < accept_threshold) {
split_threshold = 0.0;
} else {
split_threshold = cur_p[0].p / 10.0;
// split_threshold = std::max(0.01, cur_p[0].p * (n_seq_cur + i > 1 ? 0.15 : 0.2));
}
std::vector<int> sa(1, s);
// LOG("Check splits: %zu\n", cur_p.size());
// attempt to split the branch if the probability is high enough
-for (int f = 1; f < 8; ++f) {
// if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
// if (n_seq_cur < n_seq_dft && cur_p[f].p > cur_p[0].p / 5) {
-double split_threshold = avg_accepted == 0 || avg_rejected == 0 || n_drafted < 16
-? p_split
-: ( std::max(double(min_accepted * 0.7), avg_accepted * 0.4)
-* (n_seq_cur >= 2 ? 0.75 : 1.0) );
// split_threshold = 0.1;
+for (int f = 1; f < std::min(8, int(cur_p.size()) - 1); ++f) {
if (n_seq_cur < n_seq_dft && cur_p[f].p >= split_threshold) {
n_split++;
LOG(">>>%d<<< splitting seq %3d into %3d on %6d (%8.3f) '%s'\n", f, s, n_seq_cur,
@@ -428,6 +615,7 @@ int main(int argc, char ** argv) {
}
// copy the draft state
drafts[n_seq_cur].split_pos = i;
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].drafting = true;
drafts[n_seq_cur].skip = true;
@@ -443,6 +631,8 @@ int main(int argc, char ** argv) {
n_seq_cur++;
} else {
LOG("Not splitting seq %3d into %3d, choice %2d @ %6d (%8.3f) '%s'\n", s, n_seq_cur, f,
cur_p[f].id, cur_p[f].p, llama_token_to_piece(ctx_dft, cur_p[f].id).c_str());
break;
}
}
@@ -488,7 +678,9 @@ int main(int argc, char ** argv) {
LOG("=== EVAL: DRAFTED ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_dft).c_str());
// evaluate the drafted tokens on the draft model
t_cur = ggml_time_us();
llama_decode(ctx_dft, batch_dft);
t_dft_gen += ggml_time_us() - t_cur;
save_logits(ctx_dft, logits_dft, n_vocab, batch_dft.n_tokens);
++n_past_cur;
++n_drafted;
@@ -509,7 +701,9 @@ int main(int argc, char ** argv) {
}
LOG("=== EVAL: TARGET ===: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
t_cur = ggml_time_us();
llama_decode(ctx_tgt, batch_tgt);
t_tgt_predict += ggml_time_us() - t_cur;
save_logits(ctx_tgt, logits_tgt, n_vocab, batch_tgt.n_tokens);
++n_past_tgt;
}
@@ -531,7 +725,9 @@ int main(int argc, char ** argv) {
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_TEE("times: target predict: %5.3f, draft gen/accept/sample: %5.3f / %5.3f / %5.3f\n",
t_tgt_predict / 1e6f, t_dft_gen / 1e6f, t_dft_accept / 1e6f, t_dft_sample / 1e6f);
// int64_t t_dft_sample = 0, t_dft_gen = 0, t_dft_accept = 0, t_tgt_predict = 0;
LOG_TEE("\n");
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict);
@@ -541,7 +737,10 @@ int main(int argc, char ** argv) {
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_split = %d\n", n_split);
LOG_TEE("n_effsplit= %d\n", n_eff_split);
LOG_TEE("n_badsplit= %d\n", n_bad_split);
LOG_TEE("n_dupsplit= %d\n", n_dup_split);
LOG_TEE("max streak= %d\n", max_streak);
LOG_TEE("\ndraft:\n");
llama_print_timings(ctx_dft);