speculative : add tree-based sampling support
ggml-ci
This commit is contained in:
parent
5261aee8d8
commit
4de5a2d473
11 changed files with 469 additions and 192 deletions
|
@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0);
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||
|
||||
// decode in batches of ctx_params.n_batch tokens
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
|
@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
|
@ -146,10 +147,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = 16;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
|
@ -177,10 +179,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = is_pp_shared ? pp : pl*pp;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
|
@ -207,10 +210,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = pl;
|
||||
|
||||
for (int j = 0; j < pl; ++j) {
|
||||
batch.token[j] = 0;
|
||||
batch.pos[j] = pp + i;
|
||||
batch.seq_id[j] = j;
|
||||
batch.logits[j] = true;
|
||||
batch.token[j] = 0;
|
||||
batch.pos[j] = pp + i;
|
||||
batch.n_seq_id[j] = 1;
|
||||
batch.seq_id[j][0] = j;
|
||||
batch.logits[j] = true;
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue