(WIP) Implement stochastic speculative decoding
This commit is contained in:
parent
6560bed3f0
commit
c1bad4a549
3 changed files with 243 additions and 43 deletions
|
@ -295,6 +295,76 @@ static llama_token llama_sampling_sample_impl(
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the token probability distribution for the logits at batch position
// `idx` without sampling a token.
//
// Steps, in order:
//   1. copy the logits of `ctx_main` at `idx` into a private buffer
//   2. add the per-token offsets from params.logit_bias
//   3. if `ctx_cfg` is non-null, apply classifier-free guidance with params.cfg_scale
//   4. rebuild ctx_sampling->cur with one candidate per vocabulary token
//   5. apply repetition / frequency / presence penalties over the last
//      `penalty_last_n` tokens (params.n_prev when penalty_last_n < 0),
//      optionally restoring the newline logit when params.penalize_nl is false
//   6. run llama_sample_softmax over the candidates
//
// Parameters:
//   ctx_sampling  sampling state; `params` and `prev` are read, `cur` is overwritten
//   ctx_main      context whose logits are read
//   ctx_cfg       optional guidance context, may be null
//   idx           batch index of the logits to use
//
// Returns a llama_token_data_array that is a NON-OWNING view of
// ctx_sampling->cur. It stays valid only until `cur` is rewritten, so callers
// that need to keep the distribution must copy the data.
// NOTE(review): callers assert `.sorted` on the result; this relies on
// llama_sample_softmax marking the array as sorted - confirm.
//
// Bias and guidance are applied to a private copy of the logits rather than to
// the context's buffer. Calling this function therefore leaves the logits of
// `ctx_main` untouched, and a later sampling call on the same `idx` does not
// see the bias/guidance applied a second time.
static llama_token_data_array llama_sample_probability_distribution_impl(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx) {
    const llama_sampling_params & params = ctx_sampling->params;

    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
    const float   penalty_repeat  = params.penalty_repeat;
    const float   penalty_freq    = params.penalty_freq;
    const float   penalty_present = params.penalty_present;
    const bool    penalize_nl     = params.penalize_nl;

    auto & prev = ctx_sampling->prev;
    auto & cur  = ctx_sampling->cur;

    // work on a private copy so the context's logits are left unmodified
    const float * logits_ctx = llama_get_logits_ith(ctx_main, idx);
    std::vector<float> logits_copy(logits_ctx, logits_ctx + n_vocab);
    float * logits = logits_copy.data();

    // apply params.logit_bias map
    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); ++it) {
        logits[it->first] += it->second;
    }

    if (ctx_cfg) {
        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
    }

    // one candidate per vocabulary token; probabilities are filled in by the softmax below
    cur.clear();
    cur.reserve(n_vocab);

    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    // apply penalties
    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
    const int penalty_tokens_used_size = std::min((int) penalty_tokens.size(), penalty_last_n);
    if (penalty_tokens_used_size) {
        // remember the newline logit so it can be restored when newlines are exempt
        const llama_token nl_token = llama_token_nl(llama_get_model(ctx_main));
        const float       nl_logit = logits[nl_token];

        llama_sample_repetition_penalties(ctx_main, &cur_p,
                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

        if (!penalize_nl) {
            for (size_t i = 0; i < cur_p.size; i++) {
                if (cur_p.data[i].id == nl_token) {
                    cur_p.data[i].logit = nl_logit;
                    break;
                }
            }
        }
    }

    llama_sample_softmax(ctx_main, &cur_p);
    return cur_p;
}
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
@ -304,6 +374,14 @@ llama_token llama_sampling_sample(
|
||||||
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Public entry point: computes the token probability distribution for the
// logits at batch index `idx` by forwarding all arguments to
// llama_sample_probability_distribution_impl and returning its result.
llama_token_data_array llama_sampling_probability_distribution(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx) {
    llama_token_data_array distribution =
        llama_sample_probability_distribution_impl(ctx_sampling, ctx_main, ctx_cfg, idx);

    return distribution;
}
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
|
|
@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
int idx = 0);
|
int idx = 0);
|
||||||
|
|
||||||
|
// returns the probability distribution over the whole vocabulary for the
// logits at batch index idx (logit bias, optional guidance via ctx_cfg,
// repetition penalties and softmax applied); no token is sampled
// note: the returned array points into ctx_sampling's candidate buffer, so it
// is only valid until that buffer is rewritten - copy the data to keep it
|
||||||

llama_token_data_array llama_sampling_probability_distribution(
|
||||||

struct llama_sampling_context * ctx_sampling,
|
||||||

struct llama_context * ctx_main,
|
||||||

struct llama_context * ctx_cfg,
|
||||||

int idx = 0);
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
|
|
@ -18,6 +18,7 @@ struct seq_draft {
|
||||||
std::vector<int> i_batch_tgt;
|
std::vector<int> i_batch_tgt;
|
||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
|
std::vector<llama_token_data_array> dist;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling;
|
struct llama_sampling_context * ctx_sampling;
|
||||||
};
|
};
|
||||||
|
@ -166,7 +167,6 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<seq_draft> drafts(n_seq_dft);
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||||
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
|
@ -196,48 +196,149 @@ int main(int argc, char ** argv) {
|
||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
int s_keep = 0;
|
int s_keep = 0;
|
||||||
|
|
||||||
|
llama_token token_id;
|
||||||
|
std::string token_str;
|
||||||
|
|
||||||
|
// loop until we fail to accept a drafted token or we run out of drafted tokens
|
||||||
while (true) {
|
while (true) {
|
||||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
||||||
|
|
||||||
// sample from the target model
|
|
||||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
|
|
||||||
|
|
||||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
|
||||||
|
|
||||||
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
|
||||||
|
|
||||||
if (!params.use_color) {
|
|
||||||
printf("%s", token_str.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (id == llama_token_eos(model_tgt)) {
|
|
||||||
has_eos = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
++n_predict;
|
|
||||||
|
|
||||||
// check if the target token matches any of the drafts
|
// check if the target token matches any of the drafts
|
||||||
|
// for stochastic sampling, attempt to match the token with the drafted tokens
|
||||||
{
|
{
|
||||||
bool matches = false;
|
bool accept = false;
|
||||||
|
if (params.sparams.temp > 0) {
|
||||||
|
// stochastic verification
|
||||||
|
|
||||||
|
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
|
float p_tgt, p_dft;
|
||||||
|
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||||
|
|
||||||
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
if (!drafts[s].active) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (i_dft >= (int) drafts[s].tokens.size()) {
|
||||||
|
drafts[s].active = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (accept) {
|
||||||
|
// if we already accepted a token, we can skip the rest
|
||||||
|
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
|
||||||
|
drafts[s].active = false;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
float r = rand() / (float) RAND_MAX;
|
||||||
|
llama_token_data_array dist_dft = drafts[s].dist[i_dft];
|
||||||
|
// acquire the probability of the token from the draft model
|
||||||
|
for (int i = 0; i < dist_tgt.size; i++) {
|
||||||
|
|
||||||
|
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
||||||
|
p_tgt = dist_tgt.data[i].p;
|
||||||
|
}
|
||||||
|
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
|
||||||
|
p_dft = dist_dft.data[i].p;
|
||||||
|
}
|
||||||
|
if (p_tgt && p_dft) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
|
||||||
|
if (r <= p_tgt / p_dft) {
|
||||||
|
s_keep = s;
|
||||||
|
accept = true;
|
||||||
|
token_id = drafts[s].tokens[i_dft];
|
||||||
|
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
|
||||||
|
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
|
||||||
|
drafts[s].active = false;
|
||||||
|
|
||||||
|
// calculate residual probability
|
||||||
|
GGML_ASSERT(dist_tgt.sorted);
|
||||||
|
GGML_ASSERT(dist_dft.sorted);
|
||||||
|
float sum_probs = 0.0f;
|
||||||
|
|
||||||
|
// sort dist by id
|
||||||
|
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||||
|
return a.id < b.id;
|
||||||
|
});
|
||||||
|
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||||
|
return a.id < b.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (int i = 0; i < dist_tgt.size; i++) {
|
||||||
|
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
||||||
|
sum_probs += dist_tgt.data[i].p;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < dist_tgt.size; i++) {
|
||||||
|
dist_tgt.data[i].p /= sum_probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort dist_tgt by p desc
|
||||||
|
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||||
|
return a.p > b.p;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i = s; i < n_seq_dft; i++) {
|
||||||
|
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||||
|
// synchronize active status for sequences with the same drafted token
|
||||||
|
drafts[i].active = drafts[i].active & accept;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
|
||||||
if (!drafts[s].active) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
|
if (!accept) {
|
||||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
|
// all drafted tokens were rejected
|
||||||
|
// sample from the target model
|
||||||
|
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||||
|
}
|
||||||
|
|
||||||
s_keep = s;
|
|
||||||
matches = true;
|
} else {
|
||||||
} else {
|
// greedy verification
|
||||||
drafts[s].active = false;
|
|
||||||
|
// sample from the target model
|
||||||
|
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
|
||||||
|
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||||
|
|
||||||
|
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||||
|
|
||||||
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
if (!drafts[s].active) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
|
||||||
|
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||||
|
|
||||||
|
s_keep = s;
|
||||||
|
accept = true;
|
||||||
|
} else {
|
||||||
|
drafts[s].active = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (matches) {
|
if (token_id == llama_token_eos(model_tgt)) {
|
||||||
|
has_eos = true;
|
||||||
|
}
|
||||||
|
++n_predict;
|
||||||
|
|
||||||
|
if (accept) {
|
||||||
++n_accept;
|
++n_accept;
|
||||||
++n_past_tgt;
|
++n_past_tgt;
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
@ -245,17 +346,21 @@ int main(int argc, char ** argv) {
|
||||||
if (params.use_color) {
|
if (params.use_color) {
|
||||||
// Color token according to its origin sequence
|
// Color token according to its origin sequence
|
||||||
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
||||||
fflush(stdout);
|
} else {
|
||||||
|
printf("%s", token_str.c_str());
|
||||||
}
|
}
|
||||||
|
fflush(stdout);
|
||||||
continue;
|
continue;
|
||||||
|
} else {
|
||||||
|
printf("%s", token_str.c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (params.use_color) {
|
}
|
||||||
printf("%s", token_str.c_str());
|
|
||||||
}
|
|
||||||
fflush(stdout);
|
|
||||||
|
|
||||||
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
{
|
||||||
|
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
|
||||||
|
|
||||||
// TODO: simplify
|
// TODO: simplify
|
||||||
{
|
{
|
||||||
|
@ -275,21 +380,25 @@ int main(int argc, char ** argv) {
|
||||||
drafts[s].active = false;
|
drafts[s].active = false;
|
||||||
drafts[s].tokens.clear();
|
drafts[s].tokens.clear();
|
||||||
drafts[s].i_batch_tgt.clear();
|
drafts[s].i_batch_tgt.clear();
|
||||||
|
// free dist and clear
|
||||||
|
for (int i = 0; i < drafts[s].dist.size(); i++) {
|
||||||
|
free(drafts[s].dist[i].data);
|
||||||
|
}
|
||||||
|
drafts[s].dist.clear();
|
||||||
}
|
}
|
||||||
// note: will be erased after the speculation phase
|
// note: will be erased after the speculation phase
|
||||||
drafts[0].tokens.push_back(id);
|
drafts[0].tokens.push_back(token_id);
|
||||||
|
drafts[0].dist.push_back(llama_token_data_array{});
|
||||||
drafts[0].i_batch_tgt.push_back(0);
|
drafts[0].i_batch_tgt.push_back(0);
|
||||||
|
|
||||||
llama_batch_clear(batch_dft);
|
llama_batch_clear(batch_dft);
|
||||||
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||||
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||||
llama_decode (ctx_dft, batch_dft);
|
llama_decode(ctx_dft, batch_dft);
|
||||||
|
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_predict > params.n_predict || has_eos) {
|
if (n_predict > params.n_predict || has_eos) {
|
||||||
|
@ -367,6 +476,7 @@ int main(int argc, char ** argv) {
|
||||||
drafts[n_seq_cur].skip = true;
|
drafts[n_seq_cur].skip = true;
|
||||||
|
|
||||||
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
||||||
|
drafts[n_seq_cur].dist = drafts[s].dist;
|
||||||
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
||||||
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
||||||
|
|
||||||
|
@ -389,6 +499,10 @@ int main(int argc, char ** argv) {
|
||||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
||||||
|
|
||||||
drafts[s].tokens.push_back(id);
|
drafts[s].tokens.push_back(id);
|
||||||
|
// save cur_p into drafts[s].dist
|
||||||
|
llama_token_data *data = (llama_token_data *)malloc(sizeof(llama_token_data) * cur_p.size());
|
||||||
|
memcpy(data, cur_p.data(), sizeof(llama_token_data) * cur_p.size());
|
||||||
|
drafts[s].dist.push_back(llama_token_data_array{data, cur_p.size(), true});
|
||||||
|
|
||||||
// add unique drafted tokens to the target batch
|
// add unique drafted tokens to the target batch
|
||||||
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||||
|
@ -440,6 +554,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
||||||
|
drafts[s].dist.erase(drafts[s].dist.begin());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue