speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
This commit is contained in:
parent
c67fe68e41
commit
0e89203b51
21 changed files with 737 additions and 578 deletions
|
@ -97,20 +97,15 @@ int main(int argc, char ** argv) {
|
|||
|
||||
fflush(stderr);
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
batch.token[i] = tokens_list[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
for (size_t i = 0; i < tokens_list.size(); ++i) {
|
||||
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
@ -146,7 +141,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
while (n_cur <= n_len) {
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0;
|
||||
llama_batch_clear(batch);
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
|
@ -198,15 +193,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = i;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
// push this new token for next evaluation
|
||||
llama_batch_add(batch, new_token_id, n_cur, { i }, true);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue