common : add llama_batch_add() and llama_batch_clear() helpers

Georgi Gerganov 2023-10-16 12:41:33 +03:00
parent 005949109d
commit 360a333145
10 changed files with 98 additions and 122 deletions

@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
void llama_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);

@@ -70,6 +70,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
@@ -124,10 +125,23 @@ void process_escapes(std::string& input);
// Model utils
//
// TODO: avoid tuple, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
// Batch utils
void llama_batch_clear(struct llama_batch & batch);
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Vocab utils
//

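For reference, here is a minimal usage sketch of the new helpers (the decode_prompt wrapper is only an illustrative name; it assumes common.h/llama.h are included, ctx is a valid llama_context * and tokens holds a tokenized prompt). It shows the clear/add/decode pattern that the examples below are converted to:

// usage sketch: submit a tokenized prompt with the new batch helpers
static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1); // token capacity, no embeddings, 1 seq id per token

    llama_batch_clear(batch);                               // resets batch.n_tokens to 0
    for (size_t i = 0; i < tokens.size(); ++i) {
        llama_batch_add(batch, tokens[i], i, { 0 }, false); // token, position, seq ids, request logits?
    }
    batch.logits[batch.n_tokens - 1] = true;                // logits only for the last prompt token

    const bool ok = llama_decode(ctx, batch) == 0;          // 0 means success

    llama_batch_free(batch);
    return ok;
}
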
@@ -53,6 +53,19 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->cur.clear();
}
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
}
if (src->grammar) {
dst->grammar = llama_grammar_copy(src->grammar);
}
dst->prev = src->prev;
}
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,

@@ -67,6 +67,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call

@@ -144,14 +144,8 @@ int main(int argc, char ** argv) {
// warm up
{
batch.n_tokens = 16;
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
for (int i = 0; i < 16; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -176,14 +170,12 @@ int main(int argc, char ** argv) {
continue;
}
batch.n_tokens = is_pp_shared ? pp : pl*pp;
llama_batch_clear(batch);
for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
const int n_tokens = is_pp_shared ? pp : pl*pp;
for (int i = 0; i < n_tokens; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
@@ -207,14 +199,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
batch.n_tokens = pl;
llama_batch_clear(batch);
for (int j = 0; j < pl; ++j) {
batch.token[j] = 0;
batch.pos[j] = pp + i;
batch.n_seq_id[j] = 1;
batch.seq_id[j][0] = j;
batch.logits[j] = true;
llama_batch_add(batch, 0, pp + i, { j }, true);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {

@@ -99,19 +99,13 @@ int main(int argc, char ** argv) {
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
// evaluate the initial prompt
batch.n_tokens = tokens_list.size();
for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token[i] = tokens_list[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
for (size_t i = 0; i < tokens_list.size(); ++i) {
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
@@ -147,7 +141,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// prepare the next batch
batch.n_tokens = 0;
llama_batch_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -199,16 +193,10 @@ int main(int argc, char ** argv) {
streams[i] += llama_token_to_piece(ctx, new_token_id);
// push this new token for next evaluation
batch.token [batch.n_tokens] = new_token_id;
batch.pos [batch.n_tokens] = n_cur;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = i;
batch.logits [batch.n_tokens] = true;
i_batch[i] = batch.n_tokens;
batch.n_tokens += 1;
// push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_cur, { i }, true);
n_decode += 1;
}

@@ -518,7 +518,6 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;

@@ -183,14 +183,8 @@ int main(int argc, char ** argv) {
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch.n_tokens = n_tokens_system;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = false;
for (int32_t i = 0; i < n_tokens_system; ++i) {
llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
}
if (llama_decode(ctx, batch) != 0) {
@@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
batch.n_tokens = 0;
llama_batch_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@@ -217,16 +211,11 @@ int main(int argc, char ** argv) {
continue;
}
batch.token [batch.n_tokens] = client.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = client.id;
batch.logits [batch.n_tokens] = true;
client.n_decoded += 1;
client.i_batch = batch.n_tokens;
batch.n_tokens += 1;
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
@@ -258,12 +247,7 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch.token [batch.n_tokens] = tokens_prompt[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = client.id;
batch.logits [batch.n_tokens] = false;
batch.n_tokens += 1;
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
}
// extract the logits only for the last token

@@ -219,18 +219,12 @@ int main(int argc, char ** argv) {
drafts[0].tokens.push_back(id);
drafts[0].i_batch_tgt.push_back(0);
{
batch_dft.n_tokens = 1;
batch_dft.token [0] = id;
batch_dft.pos [0] = n_past_dft;
batch_dft.n_seq_id[0] = 1;
batch_dft.seq_id [0][0] = 0;
batch_dft.logits [0] = true;
}
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode(ctx_dft, batch_dft);
++n_past_dft;
break;
@@ -240,20 +234,7 @@ int main(int argc, char ** argv) {
break;
}
for (int i = 0; i < n_seq_dft; ++i) {
if (ctx_sampling->grammar) {
auto & grammar_dft = drafts[0].ctx_sampling->grammar;
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}
grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
LOG("copied target grammar to draft %d grammar\n", 0);
}
drafts[i].ctx_sampling->prev = ctx_sampling->prev;
}
llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
@@ -266,12 +247,8 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
batch_tgt.n_tokens = 1;
batch_tgt.token [0] = drafts[0].tokens[0];
batch_tgt.pos [0] = n_past_tgt;
batch_tgt.n_seq_id[0] = 1;
batch_tgt.seq_id [0][0] = 0;
batch_tgt.logits [0] = true;
llama_batch_clear(batch_tgt);
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
@@ -313,6 +290,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
@@ -324,19 +302,18 @@ int main(int argc, char ** argv) {
}
// copy the draft state
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].drafting = true;
drafts[n_seq_cur].skip = true;
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
if (ctx_sampling->grammar) {
drafts[n_seq_cur].ctx_sampling->grammar =
llama_grammar_copy(drafts[s].ctx_sampling->grammar);
}
llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
sa.push_back(n_seq_cur);
n_seq_cur++;
} else {
break;
@@ -354,19 +331,14 @@ int main(int argc, char ** argv) {
auto & i_batch_dft = drafts[s].i_batch_dft;
auto & i_batch_tgt = drafts[s].i_batch_tgt;
drafted.push_back(id);
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
// add unique drafted tokens to the target batch
batch_tgt.token [batch_tgt.n_tokens] = id;
batch_tgt.pos [batch_tgt.n_tokens] = n_past_tgt + i + 1;
batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1;
batch_tgt.seq_id [batch_tgt.n_tokens][0] = s;
batch_tgt.logits [batch_tgt.n_tokens] = true;
drafted.push_back(id);
// add unique drafted tokens to the target batch
i_batch_tgt.push_back(batch_tgt.n_tokens);
batch_tgt.n_tokens++;
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
// no need to evaluate the last drafted token, since we won't use the result
if (batch_tgt.n_tokens == n_draft) {
@@ -375,15 +347,9 @@ int main(int argc, char ** argv) {
}
// add the token to the batch for batched decoding with the draft model
batch_dft.token [batch_dft.n_tokens] = id;
batch_dft.pos [batch_dft.n_tokens] = n_past_cur;
batch_dft.n_seq_id[batch_dft.n_tokens] = 1;
batch_dft.seq_id [batch_dft.n_tokens][0] = s;
batch_dft.logits [batch_dft.n_tokens] = true;
i_batch_dft = batch_dft.n_tokens;
batch_dft.n_tokens++;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
}
}
@@ -444,6 +410,11 @@ int main(int argc, char ** argv) {
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt);
llama_sampling_free(ctx_sampling);
for (int i = 0; i < n_seq_dft; ++i) {
llama_sampling_free(drafts[i].ctx_sampling);
}
llama_batch_free(batch_dft);
llama_free(ctx_tgt);
@@ -452,11 +423,6 @@ int main(int argc, char ** argv) {
llama_free(ctx_dft);
llama_free_model(model_dft);
llama_sampling_free(ctx_sampling);
for (int i = 0; i < n_seq_dft; ++i) {
llama_sampling_free(drafts[i].ctx_sampling);
}
llama_backend_free();
fprintf(stderr, "\n\n");

@@ -9342,7 +9342,7 @@ struct llama_batch llama_batch_get_one(
}
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);