speculative : improve heuristic impl
This commit is contained in:
parent
9248528d6e
commit
dddd784c4d
1 changed files with 21 additions and 12 deletions
|
@ -116,7 +116,6 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// sample from the drafted tokens if any
|
// sample from the drafted tokens if any
|
||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
bool all_accepted = false;
|
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
|
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
|
||||||
|
@ -143,9 +142,6 @@ int main(int argc, char ** argv) {
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
++i_dft;
|
++i_dft;
|
||||||
|
|
||||||
if (i_dft == (int) drafted.size()) {
|
|
||||||
all_accepted = true;
|
|
||||||
}
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,20 +149,33 @@ int main(int argc, char ** argv) {
|
||||||
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
|
// heuristic for n_draft
|
||||||
|
{
|
||||||
|
const int n_dradt_cur = (int) drafted.size();
|
||||||
|
const bool all_accepted = i_dft == n_dradt_cur;
|
||||||
|
|
||||||
|
LOG("n_draft = %d\n", n_draft);
|
||||||
|
LOG("n_draft_cur = %d\n", n_dradt_cur);
|
||||||
|
LOG("i_dft = %d\n", i_dft);
|
||||||
|
LOG("all_accepted = %d\n", all_accepted);
|
||||||
|
|
||||||
|
if (all_accepted && n_draft == n_dradt_cur) {
|
||||||
|
LOG(" - max drafted tokens accepted - n_draft += 2\n");
|
||||||
|
n_draft += 2;
|
||||||
|
} else if (all_accepted) {
|
||||||
|
LOG(" - partially drafted tokens accepted - no change\n");
|
||||||
|
} else {
|
||||||
|
LOG(" - drafted token rejected - n_draft -= 1\n");
|
||||||
|
n_draft = std::max(2, n_draft - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
drafted.clear();
|
drafted.clear();
|
||||||
drafted.push_back(id);
|
drafted.push_back(id);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (drafted.size() > 0 && all_accepted) {
|
|
||||||
n_draft += 2;
|
|
||||||
LOG("all drafted tokens accepted, n_draft = %d\n", n_draft);
|
|
||||||
} else {
|
|
||||||
n_draft = std::max(2, n_draft - 1);
|
|
||||||
LOG("drafted token rejected, n_draft = %d\n", n_draft);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_predict > params.n_predict || has_eos) {
|
if (n_predict > params.n_predict || has_eos) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue