speculative : simplify the implementation (#10504)

ggml-ci
Georgi Gerganov, 2024-11-26 12:29:38 +02:00 (committed by GitHub)
parent 9a4b79bcfa
commit 811872a59d


@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
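As far as I can tell, `llama_tokens` is simply the `std::vector<llama_token>` alias from `common.h`, so that part of the change is cosmetic; the functional addition is the `reserve` call, which gives the vector its final capacity up front so the per-token `push_back` calls introduced below never reallocate. A minimal sketch of the effect, with the alias inlined and an illustrative context size standing in for `llama_n_ctx(ctx_tgt)`:

#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token  = int32_t;                  // matches the llama.h typedef
using llama_tokens = std::vector<llama_token>; // the common.h alias used in the hunk

int main() {
    const int n_ctx = 4096; // stand-in for llama_n_ctx(ctx_tgt)

    llama_tokens prompt_tgt = {1, 2, 3};
    prompt_tgt.reserve(n_ctx); // capacity fixed once, up front

    const llama_token * data_before = prompt_tgt.data();

    // grow to the context size, as the generation loop does one token at a time
    while ((int) prompt_tgt.size() < n_ctx) {
        prompt_tgt.push_back(0);
    }

    // reserve guarantees the buffer never moved while growing within capacity
    printf("stable storage: %s\n", prompt_tgt.data() == data_before ? "yes" : "no");
}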
@@ -181,29 +182,26 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
+
+            id_last = ids[i];
+
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
+                break;
+            }
+
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
             if (params.use_color && i + 1 < ids.size()) {
                 LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
@@ -212,11 +210,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
-            break;
-        }
-
-        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
 
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
@@ -224,11 +218,8 @@ int main(int argc, char ** argv) {
             llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
         }
 
-        prompt_tgt.push_back(id_last);
-        prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
-
-        // remember the last accepted token for the next iteration
-        id_last = id;
-
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
     }
 }
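Moving the `n_predict`/`has_eos` exit check after the cache cleanup means the loop always leaves the target KV cache truncated to the accepted tokens, even on its final iteration. The cleanup itself relies on `llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1)`, which removes every cell of sequence 0 at position `n_past` or later, discarding the speculative tail the target model evaluated but did not accept. A toy model of those semantics (the `kv_seq_rm` helper here is a hypothetical stand-in, not the real implementation):

#include <cstdio>
#include <vector>

// hypothetical stand-in for llama_kv_cache_seq_rm(ctx, 0, p0, -1): one cell per
// evaluated position, with p1 = -1 meaning "to the end of the sequence"
static void kv_seq_rm(std::vector<int> & cache, int p0) {
    if (p0 < 0) {
        p0 = 0;
    }
    if (p0 < (int) cache.size()) {
        cache.erase(cache.begin() + p0, cache.end());
    }
}

int main() {
    std::vector<int> kv(16); // the target model evaluated 16 positions this round...
    const int n_past = 12;   // ...but only 12 tokens ended up accepted

    kv_seq_rm(kv, n_past);   // trim positions [n_past, end), i.e. the rejected tail

    printf("cells kept: %zu\n", kv.size()); // -> 12
}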