From 360a33314541b70735ad0f69fefdc61e524bfb35 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 16 Oct 2023 12:41:33 +0300
Subject: [PATCH] common : add llama_batch_add() and llama_batch_clear() helpers

---
 common/common.cpp                        | 21 +++++++
 common/common.h                          | 16 ++++-
 common/sampling.cpp                      | 13 ++++
 common/sampling.h                        |  3 +
 examples/batched-bench/batched-bench.cpp | 30 +++------
 examples/batched/batched.cpp             | 24 ++-----
 examples/main/main.cpp                   |  1 -
 examples/parallel/parallel.cpp           | 30 +++------
 examples/speculative/speculative.cpp     | 80 +++++++-----------
 llama.cpp                                |  2 +-
 10 files changed, 98 insertions(+), 122 deletions(-)
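Usage note (annotation, not part of the patch): the two helpers replace the
manual field-by-field batch setup repeated across the examples. A minimal
sketch of the intended call pattern, assuming llama.h/common.h are included
and that the context and tokenized prompt come from the usual setup code:

    static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
        llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);

        llama_batch_clear(batch);                    // start from n_tokens == 0
        for (size_t i = 0; i < tokens.size(); ++i) {
            // token id, position, sequence ids, whether to compute logits
            llama_batch_add(batch, tokens[i], i, { 0 }, false);
        }
        batch.logits[batch.n_tokens - 1] = true;     // logits only for the last token

        const bool ok = llama_decode(ctx, batch) == 0;

        llama_batch_free(batch);
        return ok;
    }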
diff --git a/common/common.cpp b/common/common.cpp
index 9c4f7df20..76788817d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     return cparams;
 }
 
+void llama_batch_clear(struct llama_batch & batch) {
+    batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+                 struct llama_batch & batch,
+                        llama_token   id,
+                          llama_pos   pos,
+    const std::vector<llama_seq_id> & seq_ids,
+                               bool   logits) {
+    batch.token   [batch.n_tokens] = id;
+    batch.pos     [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits  [batch.n_tokens] = logits;
+
+    batch.n_tokens++;
+}
+
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
 
diff --git a/common/common.h b/common/common.h
index 36fd44166..1d1b8a508 100644
--- a/common/common.h
+++ b/common/common.h
@@ -70,6 +70,7 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
     std::string logdir = "";             // directory in which to save YAML log files
 
+    // TODO: avoid tuple, use struct
     std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
     std::string lora_base = "";                               // base model path for the lora adapter
 
@@ -124,10 +125,23 @@ void process_escapes(std::string& input);
 // Model utils
 //
 
+// TODO: avoid tuple, use struct
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
-struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
+
+struct llama_model_params   llama_model_params_from_gpt_params (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
+// Batch utils
+
+void llama_batch_clear(struct llama_batch & batch);
+
+void llama_batch_add(
+                 struct llama_batch & batch,
+                        llama_token   id,
+                          llama_pos   pos,
+    const std::vector<llama_seq_id> & seq_ids,
+                               bool   logits);
+
 //
 // Vocab utils
 //
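Note (annotation, not part of the patch): llama_batch_add() performs no
bounds checking; it writes at index batch.n_tokens and then increments it,
so the batch must come from llama_batch_init() with enough capacity for the
tokens added between two clears, and with n_seq_max at least seq_ids.size()
for any single call. A sketch of a token shared by two sequences (token_id
and pos are assumed to exist):

    llama_batch batch = llama_batch_init(512, 0, 2);        // capacity 512 tokens, up to 2 seq ids per token

    llama_batch_clear(batch);
    llama_batch_add(batch, token_id, pos, { 0, 1 }, false); // token belongs to sequences 0 and 1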
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 426629c45..388085fdc 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -53,6 +53,19 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
+    if (dst->grammar) {
+        llama_grammar_free(dst->grammar);
+        dst->grammar = nullptr;
+    }
+
+    if (src->grammar) {
+        dst->grammar = llama_grammar_copy(src->grammar);
+    }
+
+    dst->prev = src->prev;
+}
+
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
diff --git a/common/sampling.h b/common/sampling.h
index 6f6bc31f9..bb3c6a63c 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -67,6 +67,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Copy the sampler context
+void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
+
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
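Annotation (not part of the patch): llama_sampling_cp() is what lets each
draft sequence in speculative.cpp carry its own grammar state instead of the
ad-hoc copy loop removed below. A sketch of the copy semantics, assuming
both contexts come from llama_sampling_init():

    // - any grammar dst already holds is freed first, so repeated copies do not leak
    // - src's grammar, if present, is deep-copied with llama_grammar_copy(),
    //   so src and dst never share a grammar pointer
    // - the list of previously sampled tokens (prev) is copied by value
    llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);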
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 3820f821d..c552eaa73 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -144,14 +144,8 @@ int main(int argc, char ** argv) {
 
     // warm up
     {
-        batch.n_tokens = 16;
-
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            batch.token[i]     = 0;
-            batch.pos[i]       = i;
-            batch.n_seq_id[i]  = 1;
-            batch.seq_id[i][0] = 0;
-            batch.logits[i]    = false;
+        for (int i = 0; i < 16; ++i) {
+            llama_batch_add(batch, 0, i, { 0 }, false);
         }
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -176,14 +170,12 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            batch.n_tokens = is_pp_shared ? pp : pl*pp;
+            llama_batch_clear(batch);
 
-            for (int i = 0; i < batch.n_tokens; ++i) {
-                batch.token[i]     = 0;
-                batch.pos[i]       = i;
-                batch.n_seq_id[i]  = 1;
-                batch.seq_id[i][0] = 0;
-                batch.logits[i]    = false;
+            const int n_tokens = is_pp_shared ? pp : pl*pp;
+
+            for (int i = 0; i < n_tokens; ++i) {
+                llama_batch_add(batch, 0, i, { 0 }, false);
             }
 
             batch.logits[batch.n_tokens - 1] = true;
@@ -207,14 +199,10 @@ int main(int argc, char ** argv) {
             const auto t_tg_start = ggml_time_us();
 
             for (int i = 0; i < tg; ++i) {
-                batch.n_tokens = pl;
+                llama_batch_clear(batch);
 
                 for (int j = 0; j < pl; ++j) {
-                    batch.token[j]     = 0;
-                    batch.pos[j]       = pp + i;
-                    batch.n_seq_id[j]  = 1;
-                    batch.seq_id[j][0] = j;
-                    batch.logits[j]    = true;
+                    llama_batch_add(batch, 0, pp + i, { j }, true);
                 }
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index d1c66af46..155212165 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -99,19 +99,13 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-
     llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
 
     // evaluate the initial prompt
-    batch.n_tokens = tokens_list.size();
-
-    for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i]     = tokens_list[i];
-        batch.pos[i]       = i;
-        batch.n_seq_id[i]  = 1;
-        batch.seq_id[i][0] = 0;
-        batch.logits[i]    = false;
+    for (size_t i = 0; i < tokens_list.size(); ++i) {
+        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
     }
+    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
 
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
@@ -147,7 +141,7 @@ int main(int argc, char ** argv) {
 
     while (n_cur <= n_len) {
         // prepare the next batch
-        batch.n_tokens = 0;
+        llama_batch_clear(batch);
 
         // sample the next token for each parallel sequence / stream
        for (int32_t i = 0; i < n_parallel; ++i) {
@@ -199,16 +193,10 @@ int main(int argc, char ** argv) {
 
             streams[i] += llama_token_to_piece(ctx, new_token_id);
 
-            // push this new token for next evaluation
-            batch.token   [batch.n_tokens]    = new_token_id;
-            batch.pos     [batch.n_tokens]    = n_cur;
-            batch.n_seq_id[batch.n_tokens]    = 1;
-            batch.seq_id  [batch.n_tokens][0] = i;
-            batch.logits  [batch.n_tokens]    = true;
-
             i_batch[i] = batch.n_tokens;
 
-            batch.n_tokens += 1;
+            // push this new token for next evaluation
+            llama_batch_add(batch, new_token_id, n_cur, { i }, true);
 
             n_decode += 1;
         }
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3d3386082..316c7bf05 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -518,7 +518,6 @@ int main(int argc, char ** argv) {
 
             // evaluate tokens in batches
             // embd is typically prepared beforehand to fit within a batch, but not always
-
             if (ctx_guidance) {
                 int input_size = 0;
                 llama_token * input_buf = NULL;
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 4d22ee4ce..69f9526a4 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -183,14 +183,8 @@ int main(int argc, char ** argv) {
     {
         LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
 
-        batch.n_tokens = n_tokens_system;
-
-        for (int32_t i = 0; i < batch.n_tokens; ++i) {
-            batch.token[i]     = tokens_system[i];
-            batch.pos[i]       = i;
-            batch.n_seq_id[i]  = 1;
-            batch.seq_id[i][0] = 0;
-            batch.logits[i]    = false;
+        for (int32_t i = 0; i < n_tokens_system; ++i) {
+            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0) {
@@ -209,7 +203,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("Processing requests ...\n\n");
 
     while (true) {
-        batch.n_tokens = 0;
+        llama_batch_clear(batch);
 
         // decode any currently ongoing sequences
         for (auto & client : clients) {
@@ -217,16 +211,11 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            batch.token   [batch.n_tokens]    = client.sampled;
-            batch.pos     [batch.n_tokens]    = n_tokens_system + client.n_prompt + client.n_decoded;
-            batch.n_seq_id[batch.n_tokens]    = 1;
-            batch.seq_id  [batch.n_tokens][0] = client.id;
-            batch.logits  [batch.n_tokens]    = true;
-
-            client.n_decoded += 1;
             client.i_batch = batch.n_tokens;
 
-            batch.n_tokens += 1;
+            llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+
+            client.n_decoded += 1;
         }
 
         if (batch.n_tokens == 0) {
@@ -258,12 +247,7 @@ int main(int argc, char ** argv) {
                 tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
 
                 for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-                    batch.token   [batch.n_tokens]    = tokens_prompt[i];
-                    batch.pos     [batch.n_tokens]    = i + n_tokens_system;
-                    batch.n_seq_id[batch.n_tokens]    = 1;
-                    batch.seq_id  [batch.n_tokens][0] = client.id;
-                    batch.logits  [batch.n_tokens]    = false;
-                    batch.n_tokens += 1;
+                    llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
                 }
 
                 // extract the logits only for the last token
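Annotation (not part of the patch): a recurring pattern in the converted
examples is that the caller records the index a token will occupy before
llama_batch_add() increments n_tokens, and later uses that index to fetch
the token's logits. A sketch (ctx, batch, and the id/pos/seq_id values are
assumed to exist):

    const int32_t i_batch = batch.n_tokens;       // index this token will occupy
    llama_batch_add(batch, id, pos, { seq_id }, true);

    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits_ith(ctx, i_batch);
        // ... sample the next token from logits ...
    }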
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index ac4b13796..e873214ee 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -219,18 +219,12 @@ int main(int argc, char ** argv) {
                 drafts[0].tokens.push_back(id);
                 drafts[0].i_batch_tgt.push_back(0);
 
-                {
-                    batch_dft.n_tokens = 1;
-
-                    batch_dft.token   [0]    = id;
-                    batch_dft.pos     [0]    = n_past_dft;
-                    batch_dft.n_seq_id[0]    = 1;
-                    batch_dft.seq_id  [0][0] = 0;
-                    batch_dft.logits  [0]    = true;
-                }
+                llama_batch_clear(batch_dft);
+                llama_batch_add  (batch_dft, id, n_past_dft, { 0 }, true);
 
                 llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
-                llama_decode(ctx_dft, batch_dft);
+                llama_decode         (ctx_dft, batch_dft);
+
                 ++n_past_dft;
 
                 break;
@@ -240,20 +234,7 @@ int main(int argc, char ** argv) {
             break;
        }
 
-        for (int i = 0; i < n_seq_dft; ++i) {
-            if (ctx_sampling->grammar) {
-                auto & grammar_dft = drafts[0].ctx_sampling->grammar;
-                if (grammar_dft) {
-                    llama_grammar_free(grammar_dft);
-                }
-
-                grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
-
-                LOG("copied target grammar to draft %d grammar\n", 0);
-            }
-
-            drafts[i].ctx_sampling->prev = ctx_sampling->prev;
-        }
+        llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
 
         int n_seq_cur  = 1;
         int n_past_cur = n_past_dft;
@@ -266,12 +247,8 @@ int main(int argc, char ** argv) {
         drafts[0].drafting = true;
         drafts[0].i_batch_dft = 0;
 
-        batch_tgt.n_tokens    = 1;
-        batch_tgt.token   [0] = drafts[0].tokens[0];
-        batch_tgt.pos     [0] = n_past_tgt;
-        batch_tgt.n_seq_id[0] = 1;
-        batch_tgt.seq_id  [0][0] = 0;
-        batch_tgt.logits  [0] = true;
+        llama_batch_clear(batch_tgt);
+        llama_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
 
         // sample n_draft tokens from the draft model using tree-based sampling
         for (int i = 0; i < n_draft; ++i) {
@@ -313,6 +290,7 @@ int main(int argc, char ** argv) {
                     llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
                     llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
 
+                    // all previous tokens from this branch are now also part of the new branch
                     for (int t = 0; t < batch_tgt.n_tokens; ++t) {
                         for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
                             if (batch_tgt.seq_id[t][p] == s) {
@@ -324,19 +302,18 @@ int main(int argc, char ** argv) {
                     }
 
                     // copy the draft state
-                    drafts[n_seq_cur].active = true;
+                    drafts[n_seq_cur].active   = true;
                     drafts[n_seq_cur].drafting = true;
-                    drafts[n_seq_cur].skip = true;
-                    drafts[n_seq_cur].tokens = drafts[s].tokens;
+                    drafts[n_seq_cur].skip     = true;
+
+                    drafts[n_seq_cur].tokens      = drafts[s].tokens;
                     drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                     drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
-                    if (ctx_sampling->grammar) {
-                        drafts[n_seq_cur].ctx_sampling->grammar =
-                            llama_grammar_copy(drafts[s].ctx_sampling->grammar);
-                    }
+                    llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
 
                     sa.push_back(n_seq_cur);
+
                     n_seq_cur++;
                 } else {
                     break;
@@ -354,19 +331,14 @@ int main(int argc, char ** argv) {
                 auto & i_batch_dft = drafts[s].i_batch_dft;
                 auto & i_batch_tgt = drafts[s].i_batch_tgt;
 
-                drafted.push_back(id);
                 llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
 
-                // add unique drafted tokens to the target batch
-                batch_tgt.token   [batch_tgt.n_tokens]    = id;
-                batch_tgt.pos     [batch_tgt.n_tokens]    = n_past_tgt + i + 1;
-                batch_tgt.n_seq_id[batch_tgt.n_tokens]    = 1;
-                batch_tgt.seq_id  [batch_tgt.n_tokens][0] = s;
-                batch_tgt.logits  [batch_tgt.n_tokens]    = true;
+                drafted.push_back(id);
 
+                // add unique drafted tokens to the target batch
                 i_batch_tgt.push_back(batch_tgt.n_tokens);
 
-                batch_tgt.n_tokens++;
+                llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
                 // no need to evaluate the last drafted token, since we won't use the result
                 if (batch_tgt.n_tokens == n_draft) {
@@ -375,15 +347,9 @@ int main(int argc, char ** argv) {
                 }
 
                 // add the token to the batch for batched decoding with the draft model
-                batch_dft.token   [batch_dft.n_tokens]    = id;
-                batch_dft.pos     [batch_dft.n_tokens]    = n_past_cur;
-                batch_dft.n_seq_id[batch_dft.n_tokens]    = 1;
-                batch_dft.seq_id  [batch_dft.n_tokens][0] = s;
-                batch_dft.logits  [batch_dft.n_tokens]    = true;
-
                 i_batch_dft = batch_dft.n_tokens;
 
-                batch_dft.n_tokens++;
+                llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
 
             }
         }
@@ -444,6 +410,11 @@ int main(int argc, char ** argv) {
     LOG_TEE("\ntarget:\n");
     llama_print_timings(ctx_tgt);
 
+    llama_sampling_free(ctx_sampling);
+    for (int i = 0; i < n_seq_dft; ++i) {
+        llama_sampling_free(drafts[i].ctx_sampling);
+    }
+
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
@@ -452,11 +423,6 @@ int main(int argc, char ** argv) {
     llama_free(ctx_dft);
     llama_free_model(model_dft);
 
-    llama_sampling_free(ctx_sampling);
-    for (int i = 0; i < n_seq_dft; ++i) {
-        llama_sampling_free(drafts[i].ctx_sampling);
-    }
-
     llama_backend_free();
 
     fprintf(stderr, "\n\n");
diff --git a/llama.cpp b/llama.cpp
index fc2d245b3..7eced284b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9342,7 +9342,7 @@ struct llama_batch llama_batch_get_one(
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+    llama_batch batch = {  0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
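Annotation (not part of the patch): the one-character change in
llama_batch_init() initializes n_tokens to 0 instead of -1, so a freshly
initialized batch starts out empty, exactly as if llama_batch_clear() had
been called on it. A sketch of the invariant the helpers now rely on:

    llama_batch batch = llama_batch_init(512, 0, 1);
    // batch.n_tokens used to be -1 here; with this patch it is 0, so
    // llama_batch_add() can append to a fresh batch without a prior clear
    GGML_ASSERT(batch.n_tokens == 0);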