common : add llama_batch_add() and llama_batch_clear() helpers
commit 360a333145
parent 005949109d

10 changed files with 98 additions and 122 deletions
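The hunks in this excerpt are evidently from the speculative decoding example; the helpers themselves are added to `common`, and their bodies are not shown here. Below is a minimal sketch of what the two helpers must do, inferred from the call sites in the hunks (the exact signatures, in particular the `std::vector` parameter, are assumptions):

```cpp
#include <vector>

#include "llama.h"

// reset the batch so it can be refilled from scratch
void llama_batch_clear(struct llama_batch & batch) {
    batch.n_tokens = 0;
}

// append one token to the batch: its position, the sequences it
// belongs to, and whether logits should be computed for it
void llama_batch_add(
                 struct llama_batch & batch,
                        llama_token   id,
                          llama_pos   pos,
    const std::vector<llama_seq_id> & seq_ids,
                               bool   logits) {
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;

    batch.n_tokens++;
}
```

Note that `llama_batch_add()` advances `batch.n_tokens` itself, which is why the call sites below record a token's batch index *before* calling it.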
```diff
@@ -219,18 +219,12 @@ int main(int argc, char ** argv) {
             drafts[0].tokens.push_back(id);
             drafts[0].i_batch_tgt.push_back(0);
 
-            {
-                batch_dft.n_tokens = 1;
-                batch_dft.token   [0] = id;
-                batch_dft.pos     [0] = n_past_dft;
-                batch_dft.n_seq_id[0] = 1;
-                batch_dft.seq_id  [0][0] = 0;
-                batch_dft.logits  [0] = true;
-            }
+            llama_batch_clear(batch_dft);
+            llama_batch_add  (batch_dft, id, n_past_dft, { 0 }, true);
 
             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
-            llama_decode(ctx_dft, batch_dft);
+            llama_decode         (ctx_dft, batch_dft);
 
             ++n_past_dft;
 
             break;
```
```diff
@@ -240,20 +234,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        for (int i = 0; i < n_seq_dft; ++i) {
-            if (ctx_sampling->grammar) {
-                auto & grammar_dft = drafts[0].ctx_sampling->grammar;
-                if (grammar_dft) {
-                    llama_grammar_free(grammar_dft);
-                }
-
-                grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
-
-                LOG("copied target grammar to draft %d grammar\n", 0);
-            }
-
-            drafts[i].ctx_sampling->prev = ctx_sampling->prev;
-        }
+        llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
 
         int n_seq_cur = 1;
         int n_past_cur = n_past_dft;
```
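`llama_sampling_cp()` condenses the removed grammar bookkeeping into a single call. Its body is also not part of this excerpt; here is a sketch consistent with the code it replaces (the context type name is an assumption, the `grammar`/`prev` fields are taken from the removed lines):

```cpp
// copy sampling state (grammar + sampled-token history) from src to dst,
// mirroring the removed loop: drop dst's old grammar, clone src's
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
    if (dst->grammar) {
        llama_grammar_free(dst->grammar);
        dst->grammar = nullptr;
    }

    if (src->grammar) {
        dst->grammar = llama_grammar_copy(src->grammar);
    }

    dst->prev = src->prev;
}
```

The removed loop copied `prev` into every draft (and re-copied the grammar into `drafts[0]` on each iteration); the new code seeds only `drafts[0]`, and the other drafts pick up their state through the same call when a sequence is split (see the `@@ -324` hunk below).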
```diff
@@ -266,12 +247,8 @@ int main(int argc, char ** argv) {
         drafts[0].drafting = true;
         drafts[0].i_batch_dft = 0;
 
-        batch_tgt.n_tokens = 1;
-        batch_tgt.token   [0] = drafts[0].tokens[0];
-        batch_tgt.pos     [0] = n_past_tgt;
-        batch_tgt.n_seq_id[0] = 1;
-        batch_tgt.seq_id  [0][0] = 0;
-        batch_tgt.logits  [0] = true;
+        llama_batch_clear(batch_tgt);
+        llama_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
 
         // sample n_draft tokens from the draft model using tree-based sampling
         for (int i = 0; i < n_draft; ++i) {
@@ -313,6 +290,7 @@ int main(int argc, char ** argv) {
                        llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
                        llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
 
+                        // all previous tokens from this branch are now also part of the new branch
                        for (int t = 0; t < batch_tgt.n_tokens; ++t) {
                            for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
                                if (batch_tgt.seq_id[t][p] == s) {
@@ -324,19 +302,18 @@ int main(int argc, char ** argv) {
                        }
 
                        // copy the draft state
-                        drafts[n_seq_cur].active = true;
+                        drafts[n_seq_cur].active   = true;
                        drafts[n_seq_cur].drafting = true;
-                        drafts[n_seq_cur].skip = true;
-                        drafts[n_seq_cur].tokens = drafts[s].tokens;
+                        drafts[n_seq_cur].skip     = true;
+
+                        drafts[n_seq_cur].tokens      = drafts[s].tokens;
                        drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                        drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
-                        if (ctx_sampling->grammar) {
-                            drafts[n_seq_cur].ctx_sampling->grammar =
-                                llama_grammar_copy(drafts[s].ctx_sampling->grammar);
-                        }
+                        llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
 
                        sa.push_back(n_seq_cur);
 
                        n_seq_cur++;
                    } else {
                        break;
@@ -354,19 +331,14 @@ int main(int argc, char ** argv) {
                    auto & i_batch_dft = drafts[s].i_batch_dft;
                    auto & i_batch_tgt = drafts[s].i_batch_tgt;
 
-                    drafted.push_back(id);
                    llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
 
-                    // add unique drafted tokens to the target batch
-                    batch_tgt.token   [batch_tgt.n_tokens] = id;
-                    batch_tgt.pos     [batch_tgt.n_tokens] = n_past_tgt + i + 1;
-                    batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1;
-                    batch_tgt.seq_id  [batch_tgt.n_tokens][0] = s;
-                    batch_tgt.logits  [batch_tgt.n_tokens] = true;
+                    drafted.push_back(id);
 
+                    // add unique drafted tokens to the target batch
                    i_batch_tgt.push_back(batch_tgt.n_tokens);
 
-                    batch_tgt.n_tokens++;
+                    llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
                    // no need to evaluate the last drafted token, since we won't use the result
                    if (batch_tgt.n_tokens == n_draft) {
```
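For reference, the `llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true)` call above covers exactly the removed assignments plus the removed `batch_tgt.n_tokens++` (assuming the helper sketch at the top of this page):

```cpp
// what the helper does at this call site
batch_tgt.token   [batch_tgt.n_tokens] = id;
batch_tgt.pos     [batch_tgt.n_tokens] = n_past_tgt + i + 1;
batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1;
batch_tgt.seq_id  [batch_tgt.n_tokens][0] = s;
batch_tgt.logits  [batch_tgt.n_tokens] = true;

batch_tgt.n_tokens++;
```

Because the increment now lives inside the helper, the `i_batch_tgt.push_back(batch_tgt.n_tokens)` bookkeeping has to stay before the call, as the hunk shows.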
```diff
@@ -375,15 +347,9 @@ int main(int argc, char ** argv) {
                    }
 
                    // add the token to the batch for batched decoding with the draft model
-                    batch_dft.token   [batch_dft.n_tokens] = id;
-                    batch_dft.pos     [batch_dft.n_tokens] = n_past_cur;
-                    batch_dft.n_seq_id[batch_dft.n_tokens] = 1;
-                    batch_dft.seq_id  [batch_dft.n_tokens][0] = s;
-                    batch_dft.logits  [batch_dft.n_tokens] = true;
-
                    i_batch_dft = batch_dft.n_tokens;
 
-                    batch_dft.n_tokens++;
+                    llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
                }
            }
 
```
```diff
@@ -444,6 +410,11 @@ int main(int argc, char ** argv) {
     LOG_TEE("\ntarget:\n");
     llama_print_timings(ctx_tgt);
 
+    llama_sampling_free(ctx_sampling);
+    for (int i = 0; i < n_seq_dft; ++i) {
+        llama_sampling_free(drafts[i].ctx_sampling);
+    }
+
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
@@ -452,11 +423,6 @@ int main(int argc, char ** argv) {
     llama_free(ctx_dft);
     llama_free_model(model_dft);
 
-    llama_sampling_free(ctx_sampling);
-    for (int i = 0; i < n_seq_dft; ++i) {
-        llama_sampling_free(drafts[i].ctx_sampling);
-    }
-
     llama_backend_free();
 
     fprintf(stderr, "\n\n");
```
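The last two hunks only move code: the `llama_sampling_free()` calls relocate from after `llama_free_model(model_dft)` to just before `llama_batch_free(batch_dft)`, grouping the sampling-state cleanup with the rest of the per-run teardown; the set of frees is unchanged, only their order.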