speculative : add tree-based sampling support

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-14 17:54:02 +03:00
parent 5261aee8d8
commit 4de5a2d473
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 469 additions and 192 deletions

View file

@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
return 1;
}
llama_batch batch = llama_batch_init(n_kv_max, 0);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
@ -146,10 +147,11 @@ int main(int argc, char ** argv) {
batch.n_tokens = 16;
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
batch.token[i] = 0;
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@ -177,10 +179,11 @@ int main(int argc, char ** argv) {
batch.n_tokens = is_pp_shared ? pp : pl*pp;
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
batch.token[i] = 0;
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
}
batch.logits[batch.n_tokens - 1] = true;
@ -207,10 +210,11 @@ int main(int argc, char ** argv) {
batch.n_tokens = pl;
for (int j = 0; j < pl; ++j) {
batch.token[j] = 0;
batch.pos[j] = pp + i;
batch.seq_id[j] = j;
batch.logits[j] = true;
batch.token[j] = 0;
batch.pos[j] = pp + i;
batch.n_seq_id[j] = 1;
batch.seq_id[j][0] = j;
batch.logits[j] = true;
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {