diff --git a/common/log.h b/common/log.h index b8953fdca..3b41c1df8 100644 --- a/common/log.h +++ b/common/log.h @@ -612,6 +612,43 @@ inline std::string log_var_to_string_impl(const std::vector & var) }() \ .c_str() +#define LOG_BATCH_TOSTR_PRETTY(ctx, batch) \ + [&batch, &ctx]() \ + { \ + std::stringstream buf; \ + buf << "[ "; \ + \ + bool first = true; \ + for (int i = 0; i < batch.n_tokens; ++i) \ + { \ + if (!first) \ + buf << ", "; \ + else \ + first = false; \ + \ + auto detokenized = llama_token_to_piece(ctx, batch.token[i]); \ + \ + detokenized.erase( \ + std::remove_if( \ + detokenized.begin(), \ + detokenized.end(), \ + [](const unsigned char c) { return !std::isprint(c); }), \ + detokenized.end()); \ + \ + buf \ + << "\n" << std::to_string(i) \ + << ":token '" << detokenized << "'" \ + << ":pos " << std::to_string(batch.pos[i]) \ + << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) \ + << ":seq_id " << std::to_string(batch.seq_id[i][0]) \ + << ":logits " << std::to_string(batch.logits[i]); \ + } \ + buf << " ]"; \ + \ + return buf.str(); \ + }() \ + .c_str() + #ifdef LOG_DISABLE_LOGS #undef LOG diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 3e1e0716d..3820f821d 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -114,7 +114,7 @@ int main(int argc, char ** argv) { return 1; } - llama_batch batch = llama_batch_init(n_kv_max, 0); + llama_batch batch = llama_batch_init(n_kv_max, 0, 1); // decode in batches of ctx_params.n_batch tokens auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { @@ -123,11 +123,12 @@ int main(int argc, char ** argv) { llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, - batch.seq_id + i, - batch.logits + i, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, 0, 0, 0, // unused }; @@ -146,10 +147,11 @@ int main(int argc, char ** argv) { batch.n_tokens = 16; for (int i = 0; i < batch.n_tokens; ++i) { - batch.token[i] = 0; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + batch.token[i] = 0; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = false; } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { @@ -177,10 +179,11 @@ int main(int argc, char ** argv) { batch.n_tokens = is_pp_shared ? pp : pl*pp; for (int i = 0; i < batch.n_tokens; ++i) { - batch.token[i] = 0; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + batch.token[i] = 0; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = false; } batch.logits[batch.n_tokens - 1] = true; @@ -207,10 +210,11 @@ int main(int argc, char ** argv) { batch.n_tokens = pl; for (int j = 0; j < pl; ++j) { - batch.token[j] = 0; - batch.pos[j] = pp + i; - batch.seq_id[j] = j; - batch.logits[j] = true; + batch.token[j] = 0; + batch.pos[j] = pp + i; + batch.n_seq_id[j] = 1; + batch.seq_id[j][0] = j; + batch.logits[j] = true; } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index a88e022d6..4b4e25176 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -97,10 +97,10 @@ int main(int argc, char ** argv) { fflush(stderr); - // create a llama_batch with size 512 + // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0); + llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1); // evaluate the initial prompt batch.n_tokens = tokens_list.size(); @@ -199,10 +199,11 @@ int main(int argc, char ** argv) { streams[i] += llama_token_to_piece(ctx, new_token_id); // push this new token for next evaluation - batch.token [batch.n_tokens] = new_token_id; - batch.pos [batch.n_tokens] = n_cur; - batch.seq_id[batch.n_tokens] = i; - batch.logits[batch.n_tokens] = true; + batch.token [batch.n_tokens] = new_token_id; + batch.pos [batch.n_tokens] = n_cur; + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id [batch.n_tokens][0] = i; + batch.logits [batch.n_tokens] = true; i_batch[i] = batch.n_tokens; diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 87a5a1c26..3ce33842c 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){ if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, }; + llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h index 79e237c86..8ed8f215c 100644 --- a/examples/llava/llava-utils.h +++ b/examples/llava/llava-utils.h @@ -17,7 +17,7 @@ inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, }; + llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; if (llama_decode(ctx_llama, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 14dacc780..1ac730cc7 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -120,7 +120,7 @@ int main(int argc, char ** argv) { const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; // GG: are we sure that the should be a trailing whitespace at the end of this string? - eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past); + eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past); eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past); eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past); eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 165315db0..78dbd1fb0 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -170,7 +170,7 @@ int main(int argc, char ** argv) { // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time - llama_batch batch = llama_batch_init(n_ctx, 0); + llama_batch batch = llama_batch_init(n_ctx, 0, 1); int32_t n_total_prompt = 0; int32_t n_total_gen = 0; @@ -188,10 +188,11 @@ int main(int argc, char ** argv) { batch.n_tokens = n_tokens_system; for (int32_t i = 0; i < batch.n_tokens; ++i) { - batch.token[i] = tokens_system[i]; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + batch.token[i] = tokens_system[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = false; } if (llama_decode(ctx, batch) != 0) { @@ -218,10 +219,11 @@ int main(int argc, char ** argv) { continue; } - batch.token [batch.n_tokens] = client.sampled; - batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded; - batch.seq_id[batch.n_tokens] = client.id; - batch.logits[batch.n_tokens] = true; + batch.token [batch.n_tokens] = client.sampled; + batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded; + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id [batch.n_tokens][0] = client.id; + batch.logits [batch.n_tokens] = true; client.n_decoded += 1; client.i_batch = batch.n_tokens; @@ -258,10 +260,11 @@ int main(int argc, char ** argv) { tokens_prompt = ::llama_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - batch.token [batch.n_tokens] = tokens_prompt[i]; - batch.pos [batch.n_tokens] = i + n_tokens_system; - batch.seq_id[batch.n_tokens] = client.id; - batch.logits[batch.n_tokens] = false; + batch.token [batch.n_tokens] = tokens_prompt[i]; + batch.pos [batch.n_tokens] = i + n_tokens_system; + batch.n_seq_id[batch.n_tokens] = client.id; + batch.seq_id [batch.n_tokens][0] = client.id; + batch.logits [batch.n_tokens] = false; batch.n_tokens += 1; } @@ -305,11 +308,12 @@ int main(int argc, char ** argv) { llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, - batch.seq_id + i, - batch.logits + i, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, 0, 0, 0, // unused }; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 24fb16b78..55385f566 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -92,7 +92,7 @@ int main(int argc, char ** argv) { // create a llama_batch with size 512 // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(512, 0); + llama_batch batch = llama_batch_init(512, 0, 1); // evaluate the initial prompt batch.n_tokens = tokens_list.size(); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index c3e97d71f..7c616fb4c 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -10,9 +10,19 @@ #include struct seq_draft { + bool active = false; + bool drafting = false; + bool skip = false; + + int i_batch_dft = 0; + std::vector i_batch_tgt; + std::vector tokens; struct llama_grammar * grammar = NULL; + + std::vector last_tokens; + struct llama_sampling_context ctx_sampling; }; int main(int argc, char ** argv) { @@ -27,6 +37,9 @@ int main(int argc, char ** argv) { return 1; } + // max number of parallel drafting sequences (i.e. tree branches) + int n_seq_dft = 8; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -97,25 +110,11 @@ int main(int argc, char ** argv) { int n_past_tgt = inp.size(); int n_past_dft = inp.size(); - std::vector drafted; - - std::vector last_tokens(n_ctx); - std::fill(last_tokens.begin(), last_tokens.end(), 0); - - for (auto & id : inp) { - last_tokens.erase(last_tokens.begin()); - last_tokens.push_back(id); - } - - std::vector candidates; - candidates.reserve(n_vocab); - // used to determine end of generation bool has_eos = false; // grammar stuff - struct llama_grammar * grammar_dft = NULL; - struct llama_grammar * grammar_tgt = NULL; + struct llama_grammar * grammar = NULL; grammar_parser::parse_state parsed_grammar; @@ -128,21 +127,69 @@ int main(int argc, char ** argv) { } std::vector grammar_rules(parsed_grammar.c_rules()); - grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt); + // target model sampling context + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar); + + // TODO: move to llama_sampling_state + std::vector candidates; + candidates.reserve(n_vocab); + + std::vector last_tokens; + last_tokens.resize(n_ctx); + std::fill(last_tokens.begin(), last_tokens.end(), 0); + + for (auto & id : inp) { + last_tokens.erase(last_tokens.begin()); + last_tokens.push_back(id); + } + + // draft sequence data + std::vector drafts(n_seq_dft); + + for (int i = 0; i < n_seq_dft; ++i) { + { + auto & last_tokens = drafts[i].last_tokens; + + last_tokens.resize(n_ctx); + std::fill(last_tokens.begin(), last_tokens.end(), 0); + + for (auto & id : inp) { + last_tokens.erase(last_tokens.begin()); + last_tokens.push_back(id); + } + } + + drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar); + } + + llama_batch batch_dft = llama_batch_init(512, 0, 1); + llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft); const auto t_dec_start = ggml_time_us(); + drafts[0].i_batch_tgt.resize(1); + drafts[0].i_batch_tgt[0] = 0; + while (true) { - LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted)); + for (int i = 0; i < n_seq_dft; ++i) { + if (!drafts[i].active) continue; + + const auto & tokens = drafts[i].tokens; + + LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens)); + } int i_dft = 0; + int i_keep = 0; while (true) { + LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]); + // sample from the target model - llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft); + llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -151,6 +198,7 @@ int main(int argc, char ** argv) { //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens)); const std::string token_str = llama_token_to_piece(ctx_tgt, id); + printf("%s", token_str.c_str()); fflush(stdout); @@ -160,53 +208,71 @@ int main(int argc, char ** argv) { ++n_predict; - // check if the draft matches the target - if (i_dft < (int) drafted.size() && id == drafted[i_dft]) { - LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); - ++n_accept; - ++n_past_tgt; - ++n_past_dft; - ++i_dft; - - continue; - } - - // the drafted token was rejected or we are out of drafted tokens - - if (i_dft < (int) drafted.size()) { - LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n", - i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str()); - } else { - LOG("out of drafted tokens\n"); - } - - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); - llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); - ++n_past_dft; - - // heuristic for n_draft + // check if the target token matches any of the drafts { - const int n_draft_cur = (int) drafted.size(); - const bool all_accepted = i_dft == n_draft_cur; + bool matches = false; - LOG("n_draft = %d\n", n_draft); - LOG("n_draft_cur = %d\n", n_draft_cur); - LOG("i_dft = %d\n", i_dft); - LOG("all_accepted = %d\n", all_accepted); + for (int i = 0; i < n_seq_dft; ++i) { + if (!drafts[i].active) continue; - if (all_accepted && n_draft == n_draft_cur) { - LOG(" - max drafted tokens accepted - n_draft += 8\n"); - n_draft = std::min(30, n_draft + 8); - } else if (all_accepted) { - LOG(" - partially drafted tokens accepted - no change\n"); - } else { - LOG(" - drafted token rejected - n_draft -= 1\n"); - n_draft = std::max(2, n_draft - 1); + if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) { + LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str()); + + i_keep = i; + matches = true; + } else { + drafts[i].active = false; + } + } + + if (matches) { + ++n_accept; + ++n_past_tgt; + ++n_past_dft; + ++i_dft; + + continue; } } - drafted.clear(); - drafted.push_back(id); + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + + // TODO: simplify + { + LOG("keeping sequence %d\n", i_keep); + + llama_kv_cache_seq_keep(ctx_dft, i_keep); + llama_kv_cache_seq_cp (ctx_dft, i_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_dft, 0); + + llama_kv_cache_seq_rm (ctx_tgt, i_keep, n_past_tgt, -1); + llama_kv_cache_seq_keep(ctx_tgt, i_keep); + llama_kv_cache_seq_cp (ctx_tgt, i_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_tgt, 0); + } + + for (int i = 0; i < n_seq_dft; ++i) { + drafts[i].active = false; + drafts[i].tokens.clear(); + drafts[i].i_batch_tgt.clear(); + } + // note: will be erased after the speculation phase + drafts[0].tokens.push_back(id); + drafts[0].i_batch_tgt.push_back(0); + + { + batch_dft.n_tokens = 1; + + batch_dft.token[0] = id; + batch_dft.pos[0] = n_past_dft; + batch_dft.n_seq_id[0] = 1; + batch_dft.seq_id[0][0] = 0; + batch_dft.logits[0] = true; + } + + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + llama_decode(ctx_dft, batch_dft); + ++n_past_dft; break; } @@ -215,73 +281,202 @@ int main(int argc, char ** argv) { break; } - if (grammar_tgt) { - if (grammar_dft) { - llama_grammar_free(grammar_dft); + if (grammar) { + for (int i = 0; i < n_seq_dft; ++i) { + auto * grammar_dft = drafts[i].grammar; + if (grammar_dft) { + llama_grammar_free(grammar_dft); + } + + grammar_dft = llama_grammar_copy(ctx_sampling.grammar); + + LOG("copied target grammar to draft %d grammar\n", i); } - - grammar_dft = llama_grammar_copy(ctx_sampling.grammar); - - LOG("copied target grammar to draft grammar\n"); } - // sample n_draft tokens from the draft model using greedy decoding + int n_seq_cur = 1; int n_past_cur = n_past_dft; + + for (int i = 0; i < n_seq_dft; ++i) { + drafts[i].active = false; + drafts[i].drafting = false; + } + drafts[0].active = true; + drafts[0].drafting = true; + drafts[0].i_batch_dft = 0; + + batch_tgt.n_tokens = 1; + batch_tgt.token[0] = drafts[0].tokens[0]; + batch_tgt.pos[0] = n_past_tgt; + batch_tgt.n_seq_id[0] = 1; + batch_tgt.seq_id[0][0] = 0; + batch_tgt.logits[0] = true; + + // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { - float * logits = llama_get_logits(ctx_dft); + batch_dft.n_tokens = 0; - candidates.clear(); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + for (int s = 0; s < n_seq_dft; ++s) { + drafts[s].skip = false; } - llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].drafting || drafts[s].skip) continue; - if (grammar_dft != NULL) { - llama_sample_grammar(ctx_dft, &cur_p, grammar_dft); + auto & grammar = drafts[s].grammar; + auto & i_batch_dft = drafts[s].i_batch_dft; + + float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft); + + // TODO: optimize + candidates.clear(); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + + if (grammar != NULL) { + llama_sample_grammar(ctx_dft, &cur_p, grammar); + } + + // computes softmax and sorts the candidates + llama_sample_softmax(ctx_dft, &cur_p); + + for (int k = 0; k < 3; ++k) { + LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str()); + } + + // TODO: make this configurable + if (cur_p.data[0].p < 0.1) { + //if (cur_p.data[0].p < 2*cur_p.data[1].p) { + LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p); + drafts[s].drafting = false; + continue; + } + + std::vector sa(1, s); + + for (int f = 1; f < 8; ++f) { + // TODO: make this configurable + if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) { + LOG("splitting seq %3d into %3d\n", s, n_seq_cur); + + llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + + for (int t = 0; t < batch_tgt.n_tokens; ++t) { + for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { + if (batch_tgt.seq_id[t][p] == s) { + batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur; + batch_tgt.n_seq_id[t]++; + break; + } + } + } + + drafts[n_seq_cur] = drafts[s]; + drafts[n_seq_cur].skip = true; + // TODO: grammar + + sa.push_back(n_seq_cur); + n_seq_cur++; + } else { + break; + } + } + + // add drafted token for each sequence + for (int is = 0; is < (int) sa.size(); ++is) { + const llama_token id = cur_p.data[is].id; + + int s = sa[is]; + + auto & drafted = drafts[s].tokens; + //auto & grammar = drafts[s].grammar; + + auto & i_batch_dft = drafts[s].i_batch_dft; + auto & i_batch_tgt = drafts[s].i_batch_tgt; + + drafted.push_back(id); + + // add unique drafted tokens to the target batch + batch_tgt.token [batch_tgt.n_tokens] = id; + batch_tgt.pos [batch_tgt.n_tokens] = n_past_tgt + i + 1; + batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1; + batch_tgt.seq_id [batch_tgt.n_tokens][0] = s; + batch_tgt.logits [batch_tgt.n_tokens] = true; + + i_batch_tgt.push_back(batch_tgt.n_tokens); + + batch_tgt.n_tokens++; + + // no need to evaluate the last drafted token, since we won't use the result + if (i == n_draft - 1) { + drafts[s].drafting = false; + continue; + } + + // add the token to the batch for batched decoding with the draft model + batch_dft.token [batch_dft.n_tokens] = id; + batch_dft.pos [batch_dft.n_tokens] = n_past_cur; + batch_dft.n_seq_id[batch_dft.n_tokens] = 1; + batch_dft.seq_id [batch_dft.n_tokens][0] = s; + batch_dft.logits [batch_dft.n_tokens] = true; + + i_batch_dft = batch_dft.n_tokens; + + batch_dft.n_tokens++; + } } - // computes softmax and sorts the candidates - llama_sample_softmax(ctx_dft, &cur_p); - - for (int i = 0; i < 3; ++i) { - LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str()); - } - - // TODO: better logic? - if (cur_p.data[0].p < 2*cur_p.data[1].p) { - LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p); + // no sequence is drafting anymore + if (batch_dft.n_tokens == 0) { break; } - // drafted token - const llama_token id = cur_p.data[0].id; - - drafted.push_back(id); + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch_dft); + ++n_past_cur; ++n_drafted; - // no need to evaluate the last drafted token, since we won't use the result - if (i == n_draft - 1) { - break; + // update grammar + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].drafting) continue; + + auto & drafted = drafts[s].tokens; + auto & grammar = drafts[s].grammar; + + if (grammar != NULL) { + llama_grammar_accept_token(ctx_dft, grammar, drafted.back()); + } } - // evaluate the drafted token on the draft model - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1); - llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); - ++n_past_cur; - - if (grammar_dft != NULL) { - llama_grammar_accept_token(ctx_dft, grammar_dft, id); + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + if (batch_tgt.n_tokens >= n_draft) { + break; } } // evaluate the target model on the drafted tokens - llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1); - llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); - ++n_past_tgt; + { + llama_kv_cache_seq_keep(ctx_tgt, 0); + for (int s = 1; s < n_seq_dft; ++s) { + llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); + } - // the first token is always proposed by the traget model before the speculation loop - drafted.erase(drafted.begin()); + //LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt)); + llama_decode(ctx_tgt, batch_tgt); + ++n_past_tgt; + } + + // the first token is always proposed by the traget model before the speculation loop so we erase it here + for (int i = 0; i < n_seq_dft; ++i) { + if (!drafts[i].active) continue; + + drafts[i].tokens.erase(drafts[i].tokens.begin()); + } } auto t_dec_end = ggml_time_us(); @@ -291,7 +486,6 @@ int main(int argc, char ** argv) { LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); - // TODO: make sure these numbers are computed correctly LOG_TEE("\n"); LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_predict = %d\n", n_predict); @@ -305,15 +499,20 @@ int main(int argc, char ** argv) { LOG_TEE("\ntarget:\n"); llama_print_timings(ctx_tgt); + llama_batch_free(batch_dft); + llama_free(ctx_tgt); llama_free_model(model_tgt); llama_free(ctx_dft); llama_free_model(model_dft); - if (grammar_dft != NULL) { - llama_grammar_free(grammar_dft); - llama_grammar_free(grammar_tgt); + if (grammar) { + llama_grammar_free(grammar); + + for (int i = 0; i < n_seq_dft; ++i) { + llama_grammar_free(drafts[i].grammar); + } } llama_backend_free(); diff --git a/llama.cpp b/llama.cpp index 7ed872237..fc2d245b3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1447,7 +1447,10 @@ static bool llama_kv_cache_find_slot( for (uint32_t i = 0; i < n_tokens; i++) { cache.cells[cache.head + i].pos = batch.pos[i]; - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } } return true; @@ -1527,6 +1530,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; + } else { + cache.cells[i].seq_id.clear(); + cache.cells[i].seq_id.insert(seq_id); } } @@ -3080,7 +3086,7 @@ static struct ggml_cgraph * llm_build_llama( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -3466,7 +3472,7 @@ static struct ggml_cgraph * llm_build_baichaun( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -3865,7 +3871,7 @@ static struct ggml_cgraph * llm_build_refact( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4217,7 +4223,7 @@ static struct ggml_cgraph * llm_build_falcon( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4569,7 +4575,7 @@ static struct ggml_cgraph * llm_build_starcoder( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -4800,7 +4806,7 @@ static struct ggml_cgraph * llm_build_persimmon( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; @@ -5198,7 +5204,7 @@ static struct ggml_cgraph * llm_build_bloom( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -5466,7 +5472,7 @@ static struct ggml_cgraph * llm_build_mpt( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; for (int i = 0; i < n_kv; ++i) { if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { @@ -5767,8 +5773,11 @@ static int llama_decode_internal( // helpers for smoother batch API transistion // after deprecating the llama_eval calls, these will be removed - std::vector pos; - std::vector seq_id; + std::vector pos; + + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; if (batch.pos == nullptr) { pos.resize(n_tokens); @@ -5780,12 +5789,18 @@ static int llama_decode_internal( } if (batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { - seq_id[i] = batch.all_seq_id; + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); } - batch.seq_id = seq_id.data(); + batch.n_seq_id = n_seq_id.data(); + batch.seq_id = seq_id_arr.data(); } if (!llama_kv_cache_find_slot(kv_self, batch)) { @@ -8837,6 +8852,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam } void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); } @@ -9289,7 +9307,7 @@ int llama_eval_embd( int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { @@ -9310,20 +9328,21 @@ struct llama_batch llama_batch_get_one( llama_pos pos_0, llama_seq_id seq_id) { return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ tokens, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, - /*all_pos_0 =*/ pos_0, - /*all_pos_1 =*/ 1, - /*all_seq_id =*/ seq_id, + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, }; } -struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { - llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; +struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); @@ -9331,19 +9350,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); } - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); return batch; } void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.seq_id) free(batch.seq_id); - if (batch.logits) free(batch.logits); + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; i < batch.n_tokens; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); } int llama_decode( diff --git a/llama.h b/llama.h index a78015ada..941720c2a 100644 --- a/llama.h +++ b/llama.h @@ -133,11 +133,12 @@ extern "C" { typedef struct llama_batch { int32_t n_tokens; - llama_token * token; - float * embd; - llama_pos * pos; - llama_seq_id * seq_id; - int8_t * logits; + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below @@ -446,7 +447,8 @@ extern "C" { llama_pos pos_0, llama_seq_id seq_id); - // Allocates a batch of tokens on the heap + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids // The batch has to be freed with llama_batch_free() // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token @@ -454,7 +456,8 @@ extern "C" { // All members are left uninitialized LLAMA_API struct llama_batch llama_batch_init( int32_t n_tokens, - int32_t embd); + int32_t embd, + int32_t n_seq_max); // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch);