speculative : add tree-based sampling support
ggml-ci
This commit is contained in:
parent
5261aee8d8
commit
4de5a2d473
11 changed files with 469 additions and 192 deletions
37
common/log.h
37
common/log.h
|
@ -612,6 +612,43 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
|
||||||
}() \
|
}() \
|
||||||
.c_str()
|
.c_str()
|
||||||
|
|
||||||
|
#define LOG_BATCH_TOSTR_PRETTY(ctx, batch) \
|
||||||
|
[&batch, &ctx]() \
|
||||||
|
{ \
|
||||||
|
std::stringstream buf; \
|
||||||
|
buf << "[ "; \
|
||||||
|
\
|
||||||
|
bool first = true; \
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) \
|
||||||
|
{ \
|
||||||
|
if (!first) \
|
||||||
|
buf << ", "; \
|
||||||
|
else \
|
||||||
|
first = false; \
|
||||||
|
\
|
||||||
|
auto detokenized = llama_token_to_piece(ctx, batch.token[i]); \
|
||||||
|
\
|
||||||
|
detokenized.erase( \
|
||||||
|
std::remove_if( \
|
||||||
|
detokenized.begin(), \
|
||||||
|
detokenized.end(), \
|
||||||
|
[](const unsigned char c) { return !std::isprint(c); }), \
|
||||||
|
detokenized.end()); \
|
||||||
|
\
|
||||||
|
buf \
|
||||||
|
<< "\n" << std::to_string(i) \
|
||||||
|
<< ":token '" << detokenized << "'" \
|
||||||
|
<< ":pos " << std::to_string(batch.pos[i]) \
|
||||||
|
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i]) \
|
||||||
|
<< ":seq_id " << std::to_string(batch.seq_id[i][0]) \
|
||||||
|
<< ":logits " << std::to_string(batch.logits[i]); \
|
||||||
|
} \
|
||||||
|
buf << " ]"; \
|
||||||
|
\
|
||||||
|
return buf.str(); \
|
||||||
|
}() \
|
||||||
|
.c_str()
|
||||||
|
|
||||||
#ifdef LOG_DISABLE_LOGS
|
#ifdef LOG_DISABLE_LOGS
|
||||||
|
|
||||||
#undef LOG
|
#undef LOG
|
||||||
|
|
|
@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(n_kv_max, 0);
|
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||||
|
|
||||||
// decode in batches of ctx_params.n_batch tokens
|
// decode in batches of ctx_params.n_batch tokens
|
||||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||||
|
@ -126,6 +126,7 @@ int main(int argc, char ** argv) {
|
||||||
batch.token + i,
|
batch.token + i,
|
||||||
nullptr,
|
nullptr,
|
||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
0, 0, 0, // unused
|
||||||
|
@ -148,7 +149,8 @@ int main(int argc, char ** argv) {
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
batch.token[i] = 0;
|
batch.token[i] = 0;
|
||||||
batch.pos[i] = i;
|
batch.pos[i] = i;
|
||||||
batch.seq_id[i] = 0;
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id[i][0] = 0;
|
||||||
batch.logits[i] = false;
|
batch.logits[i] = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,7 +181,8 @@ int main(int argc, char ** argv) {
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
batch.token[i] = 0;
|
batch.token[i] = 0;
|
||||||
batch.pos[i] = i;
|
batch.pos[i] = i;
|
||||||
batch.seq_id[i] = 0;
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id[i][0] = 0;
|
||||||
batch.logits[i] = false;
|
batch.logits[i] = false;
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
@ -209,7 +212,8 @@ int main(int argc, char ** argv) {
|
||||||
for (int j = 0; j < pl; ++j) {
|
for (int j = 0; j < pl; ++j) {
|
||||||
batch.token[j] = 0;
|
batch.token[j] = 0;
|
||||||
batch.pos[j] = pp + i;
|
batch.pos[j] = pp + i;
|
||||||
batch.seq_id[j] = j;
|
batch.n_seq_id[j] = 1;
|
||||||
|
batch.seq_id[j][0] = j;
|
||||||
batch.logits[j] = true;
|
batch.logits[j] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -97,10 +97,10 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
fflush(stderr);
|
fflush(stderr);
|
||||||
|
|
||||||
// create a llama_batch with size 512
|
// create a llama_batch
|
||||||
// we use this object to submit token data for decoding
|
// we use this object to submit token data for decoding
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
batch.n_tokens = tokens_list.size();
|
batch.n_tokens = tokens_list.size();
|
||||||
|
@ -201,8 +201,9 @@ int main(int argc, char ** argv) {
|
||||||
// push this new token for next evaluation
|
// push this new token for next evaluation
|
||||||
batch.token [batch.n_tokens] = new_token_id;
|
batch.token [batch.n_tokens] = new_token_id;
|
||||||
batch.pos [batch.n_tokens] = n_cur;
|
batch.pos [batch.n_tokens] = n_cur;
|
||||||
batch.seq_id[batch.n_tokens] = i;
|
batch.n_seq_id[batch.n_tokens] = 1;
|
||||||
batch.logits[batch.n_tokens] = true;
|
batch.seq_id [batch.n_tokens][0] = i;
|
||||||
|
batch.logits [batch.n_tokens] = true;
|
||||||
|
|
||||||
i_batch[i] = batch.n_tokens;
|
i_batch[i] = batch.n_tokens;
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
|
||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
|
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||||
if (llama_decode(ctx, batch)) {
|
if (llama_decode(ctx, batch)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -17,7 +17,7 @@ inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int
|
||||||
if (n_eval > n_batch) {
|
if (n_eval > n_batch) {
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||||
if (llama_decode(ctx_llama, batch)) {
|
if (llama_decode(ctx_llama, batch)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -120,7 +120,7 @@ int main(int argc, char ** argv) {
|
||||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||||
|
|
||||||
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
||||||
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
|
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past);
|
||||||
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
|
||||||
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
|
||||||
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
|
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
|
||||||
|
|
|
@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||||
llama_batch batch = llama_batch_init(n_ctx, 0);
|
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
|
||||||
|
|
||||||
int32_t n_total_prompt = 0;
|
int32_t n_total_prompt = 0;
|
||||||
int32_t n_total_gen = 0;
|
int32_t n_total_gen = 0;
|
||||||
|
@ -190,7 +190,8 @@ int main(int argc, char ** argv) {
|
||||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||||
batch.token[i] = tokens_system[i];
|
batch.token[i] = tokens_system[i];
|
||||||
batch.pos[i] = i;
|
batch.pos[i] = i;
|
||||||
batch.seq_id[i] = 0;
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id[i][0] = 0;
|
||||||
batch.logits[i] = false;
|
batch.logits[i] = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,8 +221,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
batch.token [batch.n_tokens] = client.sampled;
|
batch.token [batch.n_tokens] = client.sampled;
|
||||||
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||||
batch.seq_id[batch.n_tokens] = client.id;
|
batch.n_seq_id[batch.n_tokens] = 1;
|
||||||
batch.logits[batch.n_tokens] = true;
|
batch.seq_id [batch.n_tokens][0] = client.id;
|
||||||
|
batch.logits [batch.n_tokens] = true;
|
||||||
|
|
||||||
client.n_decoded += 1;
|
client.n_decoded += 1;
|
||||||
client.i_batch = batch.n_tokens;
|
client.i_batch = batch.n_tokens;
|
||||||
|
@ -260,8 +262,9 @@ int main(int argc, char ** argv) {
|
||||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||||
batch.token [batch.n_tokens] = tokens_prompt[i];
|
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||||
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||||
batch.seq_id[batch.n_tokens] = client.id;
|
batch.n_seq_id[batch.n_tokens] = client.id;
|
||||||
batch.logits[batch.n_tokens] = false;
|
batch.seq_id [batch.n_tokens][0] = client.id;
|
||||||
|
batch.logits [batch.n_tokens] = false;
|
||||||
batch.n_tokens += 1;
|
batch.n_tokens += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -308,6 +311,7 @@ int main(int argc, char ** argv) {
|
||||||
batch.token + i,
|
batch.token + i,
|
||||||
nullptr,
|
nullptr,
|
||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
0, 0, 0, // unused
|
||||||
|
|
|
@ -92,7 +92,7 @@ int main(int argc, char ** argv) {
|
||||||
// create a llama_batch with size 512
|
// create a llama_batch with size 512
|
||||||
// we use this object to submit token data for decoding
|
// we use this object to submit token data for decoding
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(512, 0);
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
batch.n_tokens = tokens_list.size();
|
batch.n_tokens = tokens_list.size();
|
||||||
|
|
|
@ -10,9 +10,19 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
struct seq_draft {
|
struct seq_draft {
|
||||||
|
bool active = false;
|
||||||
|
bool drafting = false;
|
||||||
|
bool skip = false;
|
||||||
|
|
||||||
|
int i_batch_dft = 0;
|
||||||
|
std::vector<int> i_batch_tgt;
|
||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
|
|
||||||
struct llama_grammar * grammar = NULL;
|
struct llama_grammar * grammar = NULL;
|
||||||
|
|
||||||
|
std::vector<llama_token> last_tokens;
|
||||||
|
struct llama_sampling_context ctx_sampling;
|
||||||
};
|
};
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
|
@ -27,6 +37,9 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// max number of parallel drafting sequences (i.e. tree branches)
|
||||||
|
int n_seq_dft = 8;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("speculative", "log"));
|
log_set_target(log_filename_generator("speculative", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
|
@ -97,25 +110,11 @@ int main(int argc, char ** argv) {
|
||||||
int n_past_tgt = inp.size();
|
int n_past_tgt = inp.size();
|
||||||
int n_past_dft = inp.size();
|
int n_past_dft = inp.size();
|
||||||
|
|
||||||
std::vector<llama_token> drafted;
|
|
||||||
|
|
||||||
std::vector<llama_token> last_tokens(n_ctx);
|
|
||||||
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
|
||||||
|
|
||||||
for (auto & id : inp) {
|
|
||||||
last_tokens.erase(last_tokens.begin());
|
|
||||||
last_tokens.push_back(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
|
|
||||||
// used to determine end of generation
|
// used to determine end of generation
|
||||||
bool has_eos = false;
|
bool has_eos = false;
|
||||||
|
|
||||||
// grammar stuff
|
// grammar stuff
|
||||||
struct llama_grammar * grammar_dft = NULL;
|
struct llama_grammar * grammar = NULL;
|
||||||
struct llama_grammar * grammar_tgt = NULL;
|
|
||||||
|
|
||||||
grammar_parser::parse_state parsed_grammar;
|
grammar_parser::parse_state parsed_grammar;
|
||||||
|
|
||||||
|
@ -128,21 +127,69 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||||
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
|
// target model sampling context
|
||||||
|
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
|
||||||
|
|
||||||
|
// TODO: move to llama_sampling_state
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
|
||||||
|
std::vector<llama_token> last_tokens;
|
||||||
|
last_tokens.resize(n_ctx);
|
||||||
|
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||||
|
|
||||||
|
for (auto & id : inp) {
|
||||||
|
last_tokens.erase(last_tokens.begin());
|
||||||
|
last_tokens.push_back(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// draft sequence data
|
||||||
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
{
|
||||||
|
auto & last_tokens = drafts[i].last_tokens;
|
||||||
|
|
||||||
|
last_tokens.resize(n_ctx);
|
||||||
|
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||||
|
|
||||||
|
for (auto & id : inp) {
|
||||||
|
last_tokens.erase(last_tokens.begin());
|
||||||
|
last_tokens.push_back(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch batch_dft = llama_batch_init(512, 0, 1);
|
||||||
|
llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft);
|
||||||
|
|
||||||
const auto t_dec_start = ggml_time_us();
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
|
drafts[0].i_batch_tgt.resize(1);
|
||||||
|
drafts[0].i_batch_tgt[0] = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
if (!drafts[i].active) continue;
|
||||||
|
|
||||||
|
const auto & tokens = drafts[i].tokens;
|
||||||
|
|
||||||
|
LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens));
|
||||||
|
}
|
||||||
|
|
||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
|
int i_keep = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
|
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||||
last_tokens.erase(last_tokens.begin());
|
last_tokens.erase(last_tokens.begin());
|
||||||
|
@ -151,6 +198,7 @@ int main(int argc, char ** argv) {
|
||||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
|
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
|
||||||
|
|
||||||
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
||||||
|
|
||||||
printf("%s", token_str.c_str());
|
printf("%s", token_str.c_str());
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
|
@ -160,9 +208,24 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
++n_predict;
|
++n_predict;
|
||||||
|
|
||||||
// check if the draft matches the target
|
// check if the target token matches any of the drafts
|
||||||
if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
|
{
|
||||||
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
|
bool matches = false;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
if (!drafts[i].active) continue;
|
||||||
|
|
||||||
|
if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
|
||||||
|
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
|
||||||
|
|
||||||
|
i_keep = i;
|
||||||
|
matches = true;
|
||||||
|
} else {
|
||||||
|
drafts[i].active = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matches) {
|
||||||
++n_accept;
|
++n_accept;
|
||||||
++n_past_tgt;
|
++n_past_tgt;
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
@ -170,44 +233,47 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// the drafted token was rejected or we are out of drafted tokens
|
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
||||||
|
|
||||||
if (i_dft < (int) drafted.size()) {
|
// TODO: simplify
|
||||||
LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
|
{
|
||||||
i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
|
LOG("keeping sequence %d\n", i_keep);
|
||||||
} else {
|
|
||||||
LOG("out of drafted tokens\n");
|
llama_kv_cache_seq_keep(ctx_dft, i_keep);
|
||||||
|
llama_kv_cache_seq_cp (ctx_dft, i_keep, 0, -1, -1);
|
||||||
|
llama_kv_cache_seq_keep(ctx_dft, 0);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm (ctx_tgt, i_keep, n_past_tgt, -1);
|
||||||
|
llama_kv_cache_seq_keep(ctx_tgt, i_keep);
|
||||||
|
llama_kv_cache_seq_cp (ctx_tgt, i_keep, 0, -1, -1);
|
||||||
|
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
drafts[i].active = false;
|
||||||
|
drafts[i].tokens.clear();
|
||||||
|
drafts[i].i_batch_tgt.clear();
|
||||||
|
}
|
||||||
|
// note: will be erased after the speculation phase
|
||||||
|
drafts[0].tokens.push_back(id);
|
||||||
|
drafts[0].i_batch_tgt.push_back(0);
|
||||||
|
|
||||||
|
{
|
||||||
|
batch_dft.n_tokens = 1;
|
||||||
|
|
||||||
|
batch_dft.token[0] = id;
|
||||||
|
batch_dft.pos[0] = n_past_dft;
|
||||||
|
batch_dft.n_seq_id[0] = 1;
|
||||||
|
batch_dft.seq_id[0][0] = 0;
|
||||||
|
batch_dft.logits[0] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
|
llama_decode(ctx_dft, batch_dft);
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
// heuristic for n_draft
|
|
||||||
{
|
|
||||||
const int n_draft_cur = (int) drafted.size();
|
|
||||||
const bool all_accepted = i_dft == n_draft_cur;
|
|
||||||
|
|
||||||
LOG("n_draft = %d\n", n_draft);
|
|
||||||
LOG("n_draft_cur = %d\n", n_draft_cur);
|
|
||||||
LOG("i_dft = %d\n", i_dft);
|
|
||||||
LOG("all_accepted = %d\n", all_accepted);
|
|
||||||
|
|
||||||
if (all_accepted && n_draft == n_draft_cur) {
|
|
||||||
LOG(" - max drafted tokens accepted - n_draft += 8\n");
|
|
||||||
n_draft = std::min(30, n_draft + 8);
|
|
||||||
} else if (all_accepted) {
|
|
||||||
LOG(" - partially drafted tokens accepted - no change\n");
|
|
||||||
} else {
|
|
||||||
LOG(" - drafted token rejected - n_draft -= 1\n");
|
|
||||||
n_draft = std::max(2, n_draft - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
drafted.clear();
|
|
||||||
drafted.push_back(id);
|
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -215,21 +281,54 @@ int main(int argc, char ** argv) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (grammar_tgt) {
|
if (grammar) {
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
auto * grammar_dft = drafts[i].grammar;
|
||||||
if (grammar_dft) {
|
if (grammar_dft) {
|
||||||
llama_grammar_free(grammar_dft);
|
llama_grammar_free(grammar_dft);
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
|
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
|
||||||
|
|
||||||
LOG("copied target grammar to draft grammar\n");
|
LOG("copied target grammar to draft %d grammar\n", i);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sample n_draft tokens from the draft model using greedy decoding
|
int n_seq_cur = 1;
|
||||||
int n_past_cur = n_past_dft;
|
int n_past_cur = n_past_dft;
|
||||||
for (int i = 0; i < n_draft; ++i) {
|
|
||||||
float * logits = llama_get_logits(ctx_dft);
|
|
||||||
|
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
drafts[i].active = false;
|
||||||
|
drafts[i].drafting = false;
|
||||||
|
}
|
||||||
|
drafts[0].active = true;
|
||||||
|
drafts[0].drafting = true;
|
||||||
|
drafts[0].i_batch_dft = 0;
|
||||||
|
|
||||||
|
batch_tgt.n_tokens = 1;
|
||||||
|
batch_tgt.token[0] = drafts[0].tokens[0];
|
||||||
|
batch_tgt.pos[0] = n_past_tgt;
|
||||||
|
batch_tgt.n_seq_id[0] = 1;
|
||||||
|
batch_tgt.seq_id[0][0] = 0;
|
||||||
|
batch_tgt.logits[0] = true;
|
||||||
|
|
||||||
|
// sample n_draft tokens from the draft model using tree-based sampling
|
||||||
|
for (int i = 0; i < n_draft; ++i) {
|
||||||
|
batch_dft.n_tokens = 0;
|
||||||
|
|
||||||
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
drafts[s].skip = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
if (!drafts[s].drafting || drafts[s].skip) continue;
|
||||||
|
|
||||||
|
auto & grammar = drafts[s].grammar;
|
||||||
|
auto & i_batch_dft = drafts[s].i_batch_dft;
|
||||||
|
|
||||||
|
float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft);
|
||||||
|
|
||||||
|
// TODO: optimize
|
||||||
candidates.clear();
|
candidates.clear();
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
@ -237,51 +336,147 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
if (grammar_dft != NULL) {
|
if (grammar != NULL) {
|
||||||
llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
|
llama_sample_grammar(ctx_dft, &cur_p, grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
// computes softmax and sorts the candidates
|
// computes softmax and sorts the candidates
|
||||||
llama_sample_softmax(ctx_dft, &cur_p);
|
llama_sample_softmax(ctx_dft, &cur_p);
|
||||||
|
|
||||||
for (int i = 0; i < 3; ++i) {
|
for (int k = 0; k < 3; ++k) {
|
||||||
LOG(" - draft candidate %3d: %6d (%8.3f) '%s'\n", i, cur_p.data[i].id, cur_p.data[i].p, llama_token_to_piece(ctx_dft, cur_p.data[i].id).c_str());
|
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||||
|
k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: better logic?
|
// TODO: make this configurable
|
||||||
if (cur_p.data[0].p < 2*cur_p.data[1].p) {
|
if (cur_p.data[0].p < 0.1) {
|
||||||
LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
|
//if (cur_p.data[0].p < 2*cur_p.data[1].p) {
|
||||||
|
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p);
|
||||||
|
drafts[s].drafting = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> sa(1, s);
|
||||||
|
|
||||||
|
for (int f = 1; f < 8; ++f) {
|
||||||
|
// TODO: make this configurable
|
||||||
|
if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) {
|
||||||
|
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||||
|
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||||
|
|
||||||
|
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
||||||
|
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
|
||||||
|
if (batch_tgt.seq_id[t][p] == s) {
|
||||||
|
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
|
||||||
|
batch_tgt.n_seq_id[t]++;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// drafted token
|
drafts[n_seq_cur] = drafts[s];
|
||||||
const llama_token id = cur_p.data[0].id;
|
drafts[n_seq_cur].skip = true;
|
||||||
|
// TODO: grammar
|
||||||
|
|
||||||
|
sa.push_back(n_seq_cur);
|
||||||
|
n_seq_cur++;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add drafted token for each sequence
|
||||||
|
for (int is = 0; is < (int) sa.size(); ++is) {
|
||||||
|
const llama_token id = cur_p.data[is].id;
|
||||||
|
|
||||||
|
int s = sa[is];
|
||||||
|
|
||||||
|
auto & drafted = drafts[s].tokens;
|
||||||
|
//auto & grammar = drafts[s].grammar;
|
||||||
|
|
||||||
|
auto & i_batch_dft = drafts[s].i_batch_dft;
|
||||||
|
auto & i_batch_tgt = drafts[s].i_batch_tgt;
|
||||||
|
|
||||||
drafted.push_back(id);
|
drafted.push_back(id);
|
||||||
++n_drafted;
|
|
||||||
|
// add unique drafted tokens to the target batch
|
||||||
|
batch_tgt.token [batch_tgt.n_tokens] = id;
|
||||||
|
batch_tgt.pos [batch_tgt.n_tokens] = n_past_tgt + i + 1;
|
||||||
|
batch_tgt.n_seq_id[batch_tgt.n_tokens] = 1;
|
||||||
|
batch_tgt.seq_id [batch_tgt.n_tokens][0] = s;
|
||||||
|
batch_tgt.logits [batch_tgt.n_tokens] = true;
|
||||||
|
|
||||||
|
i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||||
|
|
||||||
|
batch_tgt.n_tokens++;
|
||||||
|
|
||||||
// no need to evaluate the last drafted token, since we won't use the result
|
// no need to evaluate the last drafted token, since we won't use the result
|
||||||
if (i == n_draft - 1) {
|
if (i == n_draft - 1) {
|
||||||
|
drafts[s].drafting = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add the token to the batch for batched decoding with the draft model
|
||||||
|
batch_dft.token [batch_dft.n_tokens] = id;
|
||||||
|
batch_dft.pos [batch_dft.n_tokens] = n_past_cur;
|
||||||
|
batch_dft.n_seq_id[batch_dft.n_tokens] = 1;
|
||||||
|
batch_dft.seq_id [batch_dft.n_tokens][0] = s;
|
||||||
|
batch_dft.logits [batch_dft.n_tokens] = true;
|
||||||
|
|
||||||
|
i_batch_dft = batch_dft.n_tokens;
|
||||||
|
|
||||||
|
batch_dft.n_tokens++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// no sequence is drafting anymore
|
||||||
|
if (batch_dft.n_tokens == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the drafted token on the draft model
|
// evaluate the drafted tokens on the draft model
|
||||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
|
llama_decode(ctx_dft, batch_dft);
|
||||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
|
|
||||||
++n_past_cur;
|
++n_past_cur;
|
||||||
|
++n_drafted;
|
||||||
|
|
||||||
if (grammar_dft != NULL) {
|
// update grammar
|
||||||
llama_grammar_accept_token(ctx_dft, grammar_dft, id);
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
if (!drafts[s].drafting) continue;
|
||||||
|
|
||||||
|
auto & drafted = drafts[s].tokens;
|
||||||
|
auto & grammar = drafts[s].grammar;
|
||||||
|
|
||||||
|
if (grammar != NULL) {
|
||||||
|
llama_grammar_accept_token(ctx_dft, grammar, drafted.back());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||||
|
if (batch_tgt.n_tokens >= n_draft) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the target model on the drafted tokens
|
// evaluate the target model on the drafted tokens
|
||||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
|
{
|
||||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
|
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||||
++n_past_tgt;
|
for (int s = 1; s < n_seq_dft; ++s) {
|
||||||
|
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||||
|
}
|
||||||
|
|
||||||
// the first token is always proposed by the traget model before the speculation loop
|
//LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt));
|
||||||
drafted.erase(drafted.begin());
|
llama_decode(ctx_tgt, batch_tgt);
|
||||||
|
++n_past_tgt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the first token is always proposed by the traget model before the speculation loop so we erase it here
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
if (!drafts[i].active) continue;
|
||||||
|
|
||||||
|
drafts[i].tokens.erase(drafts[i].tokens.begin());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto t_dec_end = ggml_time_us();
|
auto t_dec_end = ggml_time_us();
|
||||||
|
@ -291,7 +486,6 @@ int main(int argc, char ** argv) {
|
||||||
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
||||||
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||||
|
|
||||||
// TODO: make sure these numbers are computed correctly
|
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
LOG_TEE("n_draft = %d\n", n_draft);
|
LOG_TEE("n_draft = %d\n", n_draft);
|
||||||
LOG_TEE("n_predict = %d\n", n_predict);
|
LOG_TEE("n_predict = %d\n", n_predict);
|
||||||
|
@ -305,15 +499,20 @@ int main(int argc, char ** argv) {
|
||||||
LOG_TEE("\ntarget:\n");
|
LOG_TEE("\ntarget:\n");
|
||||||
llama_print_timings(ctx_tgt);
|
llama_print_timings(ctx_tgt);
|
||||||
|
|
||||||
|
llama_batch_free(batch_dft);
|
||||||
|
|
||||||
llama_free(ctx_tgt);
|
llama_free(ctx_tgt);
|
||||||
llama_free_model(model_tgt);
|
llama_free_model(model_tgt);
|
||||||
|
|
||||||
llama_free(ctx_dft);
|
llama_free(ctx_dft);
|
||||||
llama_free_model(model_dft);
|
llama_free_model(model_dft);
|
||||||
|
|
||||||
if (grammar_dft != NULL) {
|
if (grammar) {
|
||||||
llama_grammar_free(grammar_dft);
|
llama_grammar_free(grammar);
|
||||||
llama_grammar_free(grammar_tgt);
|
|
||||||
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
|
llama_grammar_free(drafts[i].grammar);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
|
|
63
llama.cpp
63
llama.cpp
|
@ -1447,7 +1447,10 @@ static bool llama_kv_cache_find_slot(
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
cache.cells[cache.head + i].pos = batch.pos[i];
|
cache.cells[cache.head + i].pos = batch.pos[i];
|
||||||
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
|
|
||||||
|
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
||||||
|
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -1527,6 +1530,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
if (new_head == cache.size) new_head = i;
|
if (new_head == cache.size) new_head = i;
|
||||||
|
} else {
|
||||||
|
cache.cells[i].seq_id.clear();
|
||||||
|
cache.cells[i].seq_id.insert(seq_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3080,7 +3086,7 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -3466,7 +3472,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -3865,7 +3871,7 @@ static struct ggml_cgraph * llm_build_refact(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -4217,7 +4223,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -4569,7 +4575,7 @@ static struct ggml_cgraph * llm_build_starcoder(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -4800,7 +4806,7 @@ static struct ggml_cgraph * llm_build_persimmon(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||||
|
@ -5198,7 +5204,7 @@ static struct ggml_cgraph * llm_build_bloom(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -5466,7 +5472,7 @@ static struct ggml_cgraph * llm_build_mpt(
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const llama_pos pos = batch.pos[j];
|
const llama_pos pos = batch.pos[j];
|
||||||
const llama_seq_id seq_id = batch.seq_id[j];
|
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
||||||
|
@ -5768,7 +5774,10 @@ static int llama_decode_internal(
|
||||||
// helpers for smoother batch API transistion
|
// helpers for smoother batch API transistion
|
||||||
// after deprecating the llama_eval calls, these will be removed
|
// after deprecating the llama_eval calls, these will be removed
|
||||||
std::vector<llama_pos> pos;
|
std::vector<llama_pos> pos;
|
||||||
std::vector<llama_seq_id> seq_id;
|
|
||||||
|
std::vector<int32_t> n_seq_id;
|
||||||
|
std::vector<llama_seq_id *> seq_id_arr;
|
||||||
|
std::vector<std::vector<llama_seq_id>> seq_id;
|
||||||
|
|
||||||
if (batch.pos == nullptr) {
|
if (batch.pos == nullptr) {
|
||||||
pos.resize(n_tokens);
|
pos.resize(n_tokens);
|
||||||
|
@ -5780,12 +5789,18 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch.seq_id == nullptr) {
|
if (batch.seq_id == nullptr) {
|
||||||
|
n_seq_id.resize(n_tokens);
|
||||||
seq_id.resize(n_tokens);
|
seq_id.resize(n_tokens);
|
||||||
|
seq_id_arr.resize(n_tokens);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
seq_id[i] = batch.all_seq_id;
|
n_seq_id[i] = 1;
|
||||||
|
seq_id[i].resize(1);
|
||||||
|
seq_id[i][0] = batch.all_seq_id;
|
||||||
|
seq_id_arr[i] = seq_id[i].data();
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.seq_id = seq_id.data();
|
batch.n_seq_id = n_seq_id.data();
|
||||||
|
batch.seq_id = seq_id_arr.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||||
|
@ -8837,6 +8852,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||||
|
if (seq_id_src == seq_id_dst) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9289,7 +9307,7 @@ int llama_eval_embd(
|
||||||
int n_past) {
|
int n_past) {
|
||||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
||||||
|
|
||||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||||
|
|
||||||
const int ret = llama_decode_internal(*ctx, batch);
|
const int ret = llama_decode_internal(*ctx, batch);
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
|
@ -9314,6 +9332,7 @@ struct llama_batch llama_batch_get_one(
|
||||||
/*tokens =*/ tokens,
|
/*tokens =*/ tokens,
|
||||||
/*embd =*/ nullptr,
|
/*embd =*/ nullptr,
|
||||||
/*pos =*/ nullptr,
|
/*pos =*/ nullptr,
|
||||||
|
/*n_seq_id =*/ nullptr,
|
||||||
/*seq_id =*/ nullptr,
|
/*seq_id =*/ nullptr,
|
||||||
/*logits =*/ nullptr,
|
/*logits =*/ nullptr,
|
||||||
/*all_pos_0 =*/ pos_0,
|
/*all_pos_0 =*/ pos_0,
|
||||||
|
@ -9322,8 +9341,8 @@ struct llama_batch llama_batch_get_one(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
|
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
|
||||||
llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
||||||
|
|
||||||
if (embd) {
|
if (embd) {
|
||||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
||||||
|
@ -9332,7 +9351,11 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
||||||
batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
|
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
|
||||||
|
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
||||||
|
}
|
||||||
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||||
|
|
||||||
return batch;
|
return batch;
|
||||||
|
@ -9342,7 +9365,13 @@ void llama_batch_free(struct llama_batch batch) {
|
||||||
if (batch.token) free(batch.token);
|
if (batch.token) free(batch.token);
|
||||||
if (batch.embd) free(batch.embd);
|
if (batch.embd) free(batch.embd);
|
||||||
if (batch.pos) free(batch.pos);
|
if (batch.pos) free(batch.pos);
|
||||||
if (batch.seq_id) free(batch.seq_id);
|
if (batch.n_seq_id) free(batch.n_seq_id);
|
||||||
|
if (batch.seq_id) {
|
||||||
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
|
free(batch.seq_id[i]);
|
||||||
|
}
|
||||||
|
free(batch.seq_id);
|
||||||
|
}
|
||||||
if (batch.logits) free(batch.logits);
|
if (batch.logits) free(batch.logits);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
9
llama.h
9
llama.h
|
@ -136,7 +136,8 @@ extern "C" {
|
||||||
llama_token * token;
|
llama_token * token;
|
||||||
float * embd;
|
float * embd;
|
||||||
llama_pos * pos;
|
llama_pos * pos;
|
||||||
llama_seq_id * seq_id;
|
int32_t * n_seq_id;
|
||||||
|
llama_seq_id ** seq_id;
|
||||||
int8_t * logits;
|
int8_t * logits;
|
||||||
|
|
||||||
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||||
|
@ -446,7 +447,8 @@ extern "C" {
|
||||||
llama_pos pos_0,
|
llama_pos pos_0,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Allocates a batch of tokens on the heap
|
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||||
|
// Each token can be assigned up to n_seq_max sequence ids
|
||||||
// The batch has to be freed with llama_batch_free()
|
// The batch has to be freed with llama_batch_free()
|
||||||
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
|
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
|
||||||
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
||||||
|
@ -454,7 +456,8 @@ extern "C" {
|
||||||
// All members are left uninitialized
|
// All members are left uninitialized
|
||||||
LLAMA_API struct llama_batch llama_batch_init(
|
LLAMA_API struct llama_batch llama_batch_init(
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t embd);
|
int32_t embd,
|
||||||
|
int32_t n_seq_max);
|
||||||
|
|
||||||
// Frees a batch of tokens allocated with llama_batch_init()
|
// Frees a batch of tokens allocated with llama_batch_init()
|
||||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue