speculative : add tree-based sampling support
ggml-ci
This commit is contained in:
parent
5261aee8d8
commit
4de5a2d473
11 changed files with 469 additions and 192 deletions
|
@ -97,10 +97,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
fflush(stderr);
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
@ -199,10 +199,11 @@ int main(int argc, char ** argv) {
|
|||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = i;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.n_seq_id[batch.n_tokens] = 1;
|
||||
batch.seq_id [batch.n_tokens][0] = i;
|
||||
batch.logits [batch.n_tokens] = true;
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue