common : add llama_batch_add() and llama_batch_clear() helpers

This commit is contained in:
Georgi Gerganov 2023-10-16 12:41:33 +03:00
parent 005949109d
commit 360a333145
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 98 additions and 122 deletions

View file

@ -219,18 +219,12 @@ int main(int argc, char ** argv) {
drafts[0].tokens.push_back(id);
drafts[0].i_batch_tgt.push_back(0);
{
batch_dft.n_tokens = 1;
batch_dft.token [0] = id;
batch_dft.pos [0] = n_past_dft;
batch_dft.n_seq_id[0] = 1;
batch_dft.seq_id [0][0] = 0;
batch_dft.logits [0] = true;
}
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode(ctx_dft, batch_dft);
llama_decode (ctx_dft, batch_dft);
++n_past_dft;
break;
@ -240,20 +234,7 @@ int main(int argc, char ** argv) {
break;
}
for (int i = 0; i < n_seq_dft; ++i) {
if (ctx_sampling->grammar) {
auto & grammar_dft = drafts[0].ctx_sampling->grammar;
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}
grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
LOG("copied target grammar to draft %d grammar\n", 0);
}
drafts[i].ctx_sampling->prev = ctx_sampling->prev;
}
llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
@ -266,12 +247,8 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
batch_tgt.n_tokens = 1;
batch_tgt.token [0] = drafts[0].tokens[0];
batch_tgt.pos [0] = n_past_tgt;
batch_tgt.n_seq_id[0] = 1;
batch_tgt.seq_id [0][0] = 0;
batch_tgt.logits [0] = true;
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
@ -313,6 +290,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
@ -324,19 +302,18 @@ int main(int argc, char ** argv) {
}
// copy the draft state
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].drafting = true;
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
if (ctx_sampling->grammar) {
drafts[n_seq_cur].ctx_sampling->grammar =
llama_grammar_copy(drafts[s].ctx_sampling->grammar);
}
llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
sa.push_back(n_seq_cur);
n_seq_cur++;
} else {
break;
@ -354,19 +331,14 @@ int main(int argc, char ** argv) {
auto & i_batch_dft = drafts[s].i_batch_dft;
auto & i_batch_tgt = drafts[s].i_batch_tgt;
drafted.push_back(id);
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
// add unique drafted tokens to the target batch
batch_tgt.token [batch_tgt.n_tokens] = id;
batch_tgt.pos [batch_tgt.n_tokens] = n_past_tgt + i + 1;
batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1;
batch_tgt.seq_id [batch_tgt.n_tokens][0] = s;
batch_tgt.logits [batch_tgt.n_tokens] = true;
drafted.push_back(id);
// add unique drafted tokens to the target batch
i_batch_tgt.push_back(batch_tgt.n_tokens);
batch_tgt.n_tokens++;
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
// no need to evaluate the last drafted token, since we won't use the result
if (batch_tgt.n_tokens == n_draft) {
@ -375,15 +347,9 @@ int main(int argc, char ** argv) {
}
// add the token to the batch for batched decoding with the draft model
batch_dft.token [batch_dft.n_tokens] = id;
batch_dft.pos [batch_dft.n_tokens] = n_past_cur;
batch_dft.n_seq_id[batch_dft.n_tokens] = 1;
batch_dft.seq_id [batch_dft.n_tokens][0] = s;
batch_dft.logits [batch_dft.n_tokens] = true;
i_batch_dft = batch_dft.n_tokens;
batch_dft.n_tokens++;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
}
}
@ -444,6 +410,11 @@ int main(int argc, char ** argv) {
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt);
llama_sampling_free(ctx_sampling);
for (int i = 0; i < n_seq_dft; ++i) {
llama_sampling_free(drafts[i].ctx_sampling);
}
llama_batch_free(batch_dft);
llama_free(ctx_tgt);
@ -452,11 +423,6 @@ int main(int argc, char ** argv) {
llama_free(ctx_dft);
llama_free_model(model_dft);
llama_sampling_free(ctx_sampling);
for (int i = 0; i < n_seq_dft; ++i) {
llama_sampling_free(drafts[i].ctx_sampling);
}
llama_backend_free();
fprintf(stderr, "\n\n");