Add heuristic algo for speculative
This commit is contained in:
parent
35195689cd
commit
98230ef656
1 changed files with 14 additions and 1 deletions
|
@ -84,7 +84,7 @@ int main(int argc, char ** argv) {
|
|||
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
|
||||
|
||||
// how many tokens to draft each time
|
||||
const int n_draft = params.n_draft;
|
||||
int n_draft = params.n_draft;
|
||||
|
||||
int n_predict = 0;
|
||||
int n_drafted = 0;
|
||||
|
@ -116,6 +116,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// sample from the drafted tokens if any
|
||||
int i_dft = 0;
|
||||
bool all_accepted = false;
|
||||
|
||||
while (true) {
|
||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
|
||||
|
||||
|
@ -141,6 +143,9 @@ int main(int argc, char ** argv) {
|
|||
++n_past_dft;
|
||||
++i_dft;
|
||||
|
||||
if (i_dft == (int) drafted.size()) {
|
||||
all_accepted = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -154,6 +159,14 @@ int main(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
|
||||
if (drafted.size() > 0 && all_accepted) {
|
||||
n_draft += 2;
|
||||
LOG("all drafted tokens accepted, n_draft = %d\n", n_draft);
|
||||
} else {
|
||||
n_draft -= 1;
|
||||
LOG("drafted token rejected, n_draft = %d\n", n_draft);
|
||||
}
|
||||
|
||||
if (n_predict > params.n_predict || has_eos) {
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue