speculative : add tree-based sampling support
ggml-ci
This commit is contained in:
parent
5261aee8d8
commit
4de5a2d473
11 changed files with 469 additions and 192 deletions
|
@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0);
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||
|
||||
// decode in batches of ctx_params.n_batch tokens
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
|
@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
|
@ -146,10 +147,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = 16;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
|
@ -177,10 +179,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = is_pp_shared ? pp : pl*pp;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
batch.token[i] = 0;
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
|
@ -207,10 +210,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = pl;
|
||||
|
||||
for (int j = 0; j < pl; ++j) {
|
||||
batch.token[j] = 0;
|
||||
batch.pos[j] = pp + i;
|
||||
batch.seq_id[j] = j;
|
||||
batch.logits[j] = true;
|
||||
batch.token[j] = 0;
|
||||
batch.pos[j] = pp + i;
|
||||
batch.n_seq_id[j] = 1;
|
||||
batch.seq_id[j][0] = j;
|
||||
batch.logits[j] = true;
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
|
|
|
@ -97,10 +97,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
fflush(stderr);
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
@ -199,10 +199,11 @@ int main(int argc, char ** argv) {
|
|||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = i;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.n_seq_id[batch.n_tokens] = 1;
|
||||
batch.seq_id [batch.n_tokens][0] = i;
|
||||
batch.logits [batch.n_tokens] = true;
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
|
|
@ -17,7 +17,7 @@ inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
|
|
@ -120,7 +120,7 @@ int main(int argc, char ** argv) {
|
|||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||
|
||||
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
||||
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
|
||||
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past);
|
||||
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
||||
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
||||
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
|
||||
|
|
|
@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0);
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
|
@ -188,10 +188,11 @@ int main(int argc, char ** argv) {
|
|||
batch.n_tokens = n_tokens_system;
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
|
@ -218,10 +219,11 @@ int main(int argc, char ** argv) {
|
|||
continue;
|
||||
}
|
||||
|
||||
batch.token [batch.n_tokens] = client.sampled;
|
||||
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
batch.token [batch.n_tokens] = client.sampled;
|
||||
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||
batch.n_seq_id[batch.n_tokens] = 1;
|
||||
batch.seq_id [batch.n_tokens][0] = client.id;
|
||||
batch.logits [batch.n_tokens] = true;
|
||||
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
@ -258,10 +260,11 @@ int main(int argc, char ** argv) {
|
|||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = false;
|
||||
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||
batch.n_seq_id[batch.n_tokens] = client.id;
|
||||
batch.seq_id [batch.n_tokens][0] = client.id;
|
||||
batch.logits [batch.n_tokens] = false;
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
|
@ -305,11 +308,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ int main(int argc, char ** argv) {
|
|||
// create a llama_batch with size 512
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(512, 0);
|
||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
|
|
@ -10,9 +10,19 @@
|
|||
#include <vector>
|
||||
|
||||
struct seq_draft {
|
||||
bool active = false;
|
||||
bool drafting = false;
|
||||
bool skip = false;
|
||||
|
||||
int i_batch_dft = 0;
|
||||
std::vector<int> i_batch_tgt;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
struct llama_grammar * grammar = NULL;
|
||||
|
||||
std::vector<llama_token> last_tokens;
|
||||
struct llama_sampling_context ctx_sampling;
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
|
@ -27,6 +37,9 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
// max number of parallel drafting sequences (i.e. tree branches)
|
||||
int n_seq_dft = 8;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("speculative", "log"));
|
||||
LOG_TEE("Log start\n");
|
||||
|
@ -97,25 +110,11 @@ int main(int argc, char ** argv) {
|
|||
int n_past_tgt = inp.size();
|
||||
int n_past_dft = inp.size();
|
||||
|
||||
std::vector<llama_token> drafted;
|
||||
|
||||
std::vector<llama_token> last_tokens(n_ctx);
|
||||
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||
|
||||
for (auto & id : inp) {
|
||||
last_tokens.erase(last_tokens.begin());
|
||||
last_tokens.push_back(id);
|
||||
}
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
// used to determine end of generation
|
||||
bool has_eos = false;
|
||||
|
||||
// grammar stuff
|
||||
struct llama_grammar * grammar_dft = NULL;
|
||||
struct llama_grammar * grammar_tgt = NULL;
|
||||
struct llama_grammar * grammar = NULL;
|
||||
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
|
||||
|
@ -128,21 +127,69 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
}
|
||||
|
||||
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
|
||||
// target model sampling context
|
||||
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
|
||||
|
||||
// TODO: move to llama_sampling_state
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
std::vector<llama_token> last_tokens;
|
||||
last_tokens.resize(n_ctx);
|
||||
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||
|
||||
for (auto & id : inp) {
|
||||
last_tokens.erase(last_tokens.begin());
|
||||
last_tokens.push_back(id);
|
||||
}
|
||||
|
||||
// draft sequence data
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
{
|
||||
auto & last_tokens = drafts[i].last_tokens;
|
||||
|
||||
last_tokens.resize(n_ctx);
|
||||
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||
|
||||
for (auto & id : inp) {
|
||||
last_tokens.erase(last_tokens.begin());
|
||||
last_tokens.push_back(id);
|
||||
}
|
||||
}
|
||||
|
||||
drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(512, 0, 1);
|
||||
llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
drafts[0].i_batch_tgt.resize(1);
|
||||
drafts[0].i_batch_tgt[0] = 0;
|
||||
|
||||
while (true) {
|
||||
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
if (!drafts[i].active) continue;
|
||||
|
||||
const auto & tokens = drafts[i].tokens;
|
||||
|
||||
LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens));
|
||||
}
|
||||
|
||||
int i_dft = 0;
|
||||
int i_keep = 0;
|
||||
|
||||
while (true) {
|
||||
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
// sample from the target model
|
||||
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
|
||||
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
last_tokens.erase(last_tokens.begin());
|
||||
|
@ -151,6 +198,7 @@ int main(int argc, char ** argv) {
|
|||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
||||
|
||||
printf("%s", token_str.c_str());
|
||||
fflush(stdout);
|
||||
|
||||
|
@ -160,53 +208,71 @@ int main(int argc, char ** argv) {
|
|||
|
||||
++n_predict;
|
||||
|
||||
// check if the draft matches the target
|
||||
if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
|
||||
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
|
||||
++n_accept;
|
||||
++n_past_tgt;
|
||||
++n_past_dft;
|
||||
++i_dft;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// the drafted token was rejected or we are out of drafted tokens
|
||||
|
||||
if (i_dft < (int) drafted.size()) {
|
||||
LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
|
||||
i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
|
||||
} else {
|
||||
LOG("out of drafted tokens\n");
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
|
||||
++n_past_dft;
|
||||
|
||||
// heuristic for n_draft
|
||||
// check if the target token matches any of the drafts
|
||||
{
|
||||
const int n_draft_cur = (int) drafted.size();
|
||||
const bool all_accepted = i_dft == n_draft_cur;
|
||||
bool matches = false;
|
||||
|
||||
LOG("n_draft = %d\n", n_draft);
|
||||
LOG("n_draft_cur = %d\n", n_draft_cur);
|
||||
LOG("i_dft = %d\n", i_dft);
|
||||
LOG("all_accepted = %d\n", all_accepted);
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
if (!drafts[i].active) continue;
|
||||
|
||||
if (all_accepted && n_draft == n_draft_cur) {
|
||||
LOG(" - max drafted tokens accepted - n_draft += 8\n");
|
||||
n_draft = std::min(30, n_draft + 8);
|
||||
} else if (all_accepted) {
|
||||
LOG(" - partially drafted tokens accepted - no change\n");
|
||||
} else {
|
||||
LOG(" - drafted token rejected - n_draft -= 1\n");
|
||||
n_draft = std::max(2, n_draft - 1);
|
||||
if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
|
||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
|
||||
|
||||
i_keep = i;
|
||||
matches = true;
|
||||
} else {
|
||||
drafts[i].active = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (matches) {
|
||||
++n_accept;
|
||||
++n_past_tgt;
|
||||
++n_past_dft;
|
||||
++i_dft;
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
drafted.clear();
|
||||
drafted.push_back(id);
|
||||
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
||||
|
||||
// TODO: simplify
|
||||
{
|
||||
LOG("keeping sequence %d\n", i_keep);
|
||||
|
||||
llama_kv_cache_seq_keep(ctx_dft, i_keep);
|
||||
llama_kv_cache_seq_cp (ctx_dft, i_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_dft, 0);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx_tgt, i_keep, n_past_tgt, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, i_keep);
|
||||
llama_kv_cache_seq_cp (ctx_tgt, i_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
drafts[i].active = false;
|
||||
drafts[i].tokens.clear();
|
||||
drafts[i].i_batch_tgt.clear();
|
||||
}
|
||||
// note: will be erased after the speculation phase
|
||||
drafts[0].tokens.push_back(id);
|
||||
drafts[0].i_batch_tgt.push_back(0);
|
||||
|
||||
{
|
||||
batch_dft.n_tokens = 1;
|
||||
|
||||
batch_dft.token[0] = id;
|
||||
batch_dft.pos[0] = n_past_dft;
|
||||
batch_dft.n_seq_id[0] = 1;
|
||||
batch_dft.seq_id[0][0] = 0;
|
||||
batch_dft.logits[0] = true;
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
++n_past_dft;
|
||||
|
||||
break;
|
||||
}
|
||||
|
@ -215,73 +281,202 @@ int main(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
|
||||
if (grammar_tgt) {
|
||||
if (grammar_dft) {
|
||||
llama_grammar_free(grammar_dft);
|
||||
if (grammar) {
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
auto * grammar_dft = drafts[i].grammar;
|
||||
if (grammar_dft) {
|
||||
llama_grammar_free(grammar_dft);
|
||||
}
|
||||
|
||||
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
|
||||
|
||||
LOG("copied target grammar to draft %d grammar\n", i);
|
||||
}
|
||||
|
||||
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
|
||||
|
||||
LOG("copied target grammar to draft grammar\n");
|
||||
}
|
||||
|
||||
// sample n_draft tokens from the draft model using greedy decoding
|
||||
int n_seq_cur = 1;
|
||||
int n_past_cur = n_past_dft;
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
drafts[i].active = false;
|
||||
drafts[i].drafting = false;
|
||||
}
|
||||
drafts[0].active = true;
|
||||
drafts[0].drafting = true;
|
||||
drafts[0].i_batch_dft = 0;
|
||||
|
||||
batch_tgt.n_tokens = 1;
|
||||
batch_tgt.token[0] = drafts[0].tokens[0];
|
||||
batch_tgt.pos[0] = n_past_tgt;
|
||||
batch_tgt.n_seq_id[0] = 1;
|
||||
batch_tgt.seq_id[0][0] = 0;
|
||||
batch_tgt.logits[0] = true;
|
||||
|
||||
// sample n_draft tokens from the draft model using tree-based sampling
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
float * logits = llama_get_logits(ctx_dft);
|
||||
batch_dft.n_tokens = 0;
|
||||
|
||||
candidates.clear();
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].skip = false;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].drafting || drafts[s].skip) continue;
|
||||
|
||||
if (grammar_dft != NULL) {
|
||||
llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
|
||||
auto & grammar = drafts[s].grammar;
|
||||
auto & i_batch_dft = drafts[s].i_batch_dft;
|
||||
|
||||
float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft);
|
||||
|
||||
// TODO: optimize
|
||||
candidates.clear();
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
if (grammar != NULL) {
|
||||
llama_sample_grammar(ctx_dft, &cur_p, grammar);
|
||||
}
|
||||
|
||||
// computes softmax and sorts the candidates
|
||||
llama_sample_softmax(ctx_dft, &cur_p);
|
||||
|
||||
for (int k = 0; k < 3; ++k) {
|
||||
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str());
|
||||
}
|
||||
|
||||
// TODO: make this configurable
|
||||
if (cur_p.data[0].p < 0.1) {
|
||||
//if (cur_p.data[0].p < 2*cur_p.data[1].p) {
|
||||
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p);
|
||||
drafts[s].drafting = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<int> sa(1, s);
|
||||
|
||||
for (int f = 1; f < 8; ++f) {
|
||||
// TODO: make this configurable
|
||||
if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) {
|
||||
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||
|
||||
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
||||
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
|
||||
if (batch_tgt.seq_id[t][p] == s) {
|
||||
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
|
||||
batch_tgt.n_seq_id[t]++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drafts[n_seq_cur] = drafts[s];
|
||||
drafts[n_seq_cur].skip = true;
|
||||
// TODO: grammar
|
||||
|
||||
sa.push_back(n_seq_cur);
|
||||
n_seq_cur++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// add drafted token for each sequence
|
||||
for (int is = 0; is < (int) sa.size(); ++is) {
|
||||
const llama_token id = cur_p.data[is].id;
|
||||
|
||||
int s = sa[is];
|
||||
|
||||
auto & drafted = drafts[s].tokens;
|
||||
//auto & grammar = drafts[s].grammar;
|
||||
|
||||
auto & i_batch_dft = drafts[s].i_batch_dft;
|
||||
auto & i_batch_tgt = drafts[s].i_batch_tgt;
|
||||
|
||||
drafted.push_back(id);
|
||||
|
||||
// add unique drafted tokens to the target batch
|
||||
batch_tgt.token [batch_tgt.n_tokens] = id;
|
||||
batch_tgt.pos [batch_tgt.n_tokens] = n_past_tgt + i + 1;
|
||||
batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1;
|
||||
batch_tgt.seq_id [batch_tgt.n_tokens][0] = s;
|
||||
batch_tgt.logits [batch_tgt.n_tokens] = true;
|
||||
|
||||
i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
|
||||
batch_tgt.n_tokens++;
|
||||
|
||||
// no need to evaluate the last drafted token, since we won't use the result
|
||||
if (i == n_draft - 1) {
|
||||
drafts[s].drafting = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
// add the token to the batch for batched decoding with the draft model
|
||||
batch_dft.token [batch_dft.n_tokens] = id;
|
||||
batch_dft.pos [batch_dft.n_tokens] = n_past_cur;
|
||||
batch_dft.n_seq_id[batch_dft.n_tokens] = 1;
|
||||
batch_dft.seq_id [batch_dft.n_tokens][0] = s;
|
||||
batch_dft.logits [batch_dft.n_tokens] = true;
|
||||
|
||||
i_batch_dft = batch_dft.n_tokens;
|
||||
|
||||
batch_dft.n_tokens++;
|
||||
}
|
||||
}
|
||||
|
||||
// computes softmax and sorts the candidates
|
||||
llama_sample_softmax(ctx_dft, &cur_p);
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str());
|
||||
}
|
||||
|
||||
// TODO: better logic?
|
||||
if (cur_p.data[0].p < 2*cur_p.data[1].p) {
|
||||
LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
|
||||
// no sequence is drafting anymore
|
||||
if (batch_dft.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// drafted token
|
||||
const llama_token id = cur_p.data[0].id;
|
||||
|
||||
drafted.push_back(id);
|
||||
// evaluate the drafted tokens on the draft model
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
++n_past_cur;
|
||||
++n_drafted;
|
||||
|
||||
// no need to evaluate the last drafted token, since we won't use the result
|
||||
if (i == n_draft - 1) {
|
||||
break;
|
||||
// update grammar
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].drafting) continue;
|
||||
|
||||
auto & drafted = drafts[s].tokens;
|
||||
auto & grammar = drafts[s].grammar;
|
||||
|
||||
if (grammar != NULL) {
|
||||
llama_grammar_accept_token(ctx_dft, grammar, drafted.back());
|
||||
}
|
||||
}
|
||||
|
||||
// evaluate the drafted token on the draft model
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
|
||||
++n_past_cur;
|
||||
|
||||
if (grammar_dft != NULL) {
|
||||
llama_grammar_accept_token(ctx_dft, grammar_dft, id);
|
||||
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
if (batch_tgt.n_tokens >= n_draft) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
|
||||
++n_past_tgt;
|
||||
{
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
for (int s = 1; s < n_seq_dft; ++s) {
|
||||
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
// the first token is always proposed by the traget model before the speculation loop
|
||||
drafted.erase(drafted.begin());
|
||||
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
++n_past_tgt;
|
||||
}
|
||||
|
||||
// the first token is always proposed by the traget model before the speculation loop so we erase it here
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
if (!drafts[i].active) continue;
|
||||
|
||||
drafts[i].tokens.erase(drafts[i].tokens.begin());
|
||||
}
|
||||
}
|
||||
|
||||
auto t_dec_end = ggml_time_us();
|
||||
|
@ -291,7 +486,6 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
||||
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||
|
||||
// TODO: make sure these numbers are computed correctly
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("n_draft = %d\n", n_draft);
|
||||
LOG_TEE("n_predict = %d\n", n_predict);
|
||||
|
@ -305,15 +499,20 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("\ntarget:\n");
|
||||
llama_print_timings(ctx_tgt);
|
||||
|
||||
llama_batch_free(batch_dft);
|
||||
|
||||
llama_free(ctx_tgt);
|
||||
llama_free_model(model_tgt);
|
||||
|
||||
llama_free(ctx_dft);
|
||||
llama_free_model(model_dft);
|
||||
|
||||
if (grammar_dft != NULL) {
|
||||
llama_grammar_free(grammar_dft);
|
||||
llama_grammar_free(grammar_tgt);
|
||||
if (grammar) {
|
||||
llama_grammar_free(grammar);
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
llama_grammar_free(drafts[i].grammar);
|
||||
}
|
||||
}
|
||||
llama_backend_free();
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue