Add heuristic algo for speculative
This commit is contained in:
parent
35195689cd
commit
98230ef656
1 changed files with 14 additions and 1 deletions
|
@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
|
||||||
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
|
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
|
||||||
|
|
||||||
// how many tokens to draft each time
|
// how many tokens to draft each time
|
||||||
const int n_draft = params.n_draft;
|
int n_draft = params.n_draft;
|
||||||
|
|
||||||
int n_predict = 0;
|
int n_predict = 0;
|
||||||
int n_drafted = 0;
|
int n_drafted = 0;
|
||||||
|
@ -116,6 +116,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// sample from the drafted tokens if any
|
// sample from the drafted tokens if any
|
||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
|
bool all_accepted = false;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
|
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
|
||||||
|
|
||||||
|
@ -141,6 +143,9 @@ int main(int argc, char ** argv) {
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
++i_dft;
|
++i_dft;
|
||||||
|
|
||||||
|
if (i_dft == (int) drafted.size()) {
|
||||||
|
all_accepted = true;
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,6 +159,14 @@ int main(int argc, char ** argv) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (drafted.size() > 0 && all_accepted) {
|
||||||
|
n_draft += 2;
|
||||||
|
LOG("all drafted tokens accepted, n_draft = %d\n", n_draft);
|
||||||
|
} else {
|
||||||
|
n_draft -= 1;
|
||||||
|
LOG("drafted token rejected, n_draft = %d\n", n_draft);
|
||||||
|
}
|
||||||
|
|
||||||
if (n_predict > params.n_predict || has_eos) {
|
if (n_predict > params.n_predict || has_eos) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue