This commit is contained in:
Johannes Gäßler 2024-08-30 21:08:15 -05:00 committed by GitHub
commit 408c8402b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 315 additions and 113 deletions

View file

@ -52,52 +52,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()]; return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
} }
// If sample size or percentage are below these thresholds the draft is aborted early: // Sample size and percentage must meet these thresholds to be added to the draft tree:
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1}; constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1, 1, 1, 1};
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50}; constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {20, 20, 10, 10};
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2}; constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66}; constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {50, 50, 50, 50};
struct draft_candidate {
llama_draft_t draft;
float nll;
int nsampled;
};
struct compare_draft_candidate {
bool operator()(const draft_candidate & a, const draft_candidate & b){
if (a.nsampled > b.nsampled) {
return true;
}
if (a.nsampled < b.nsampled) {
return false;
}
return a.nll < b.nll;
}
};
// Helper function that tries to draft tokens from only the static ngram cache:
static void try_draft(
llama_ngram_cache & nc_static, const llama_ngram & ngram_static,
const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
const int ngram_min, std::vector<draft_candidate> & drafts_new) {
const int nsc = (ngram_min + LLAMA_NGRAM_STATIC) - (cp.draft.size() - 1);
if (nsc < (ngram_min + LLAMA_NGRAM_STATIC + 1)/2) {
return;
}
// Helper function that tries to draft a token from only the static ngram cache:
static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
if (part_static_it == nc_static.end()) { if (part_static_it == nc_static.end()) {
return -1; return;
} }
const llama_ngram_cache_part part_static = part_static_it->second; const llama_ngram_cache_part part_static = part_static_it->second;
int max_count_static = 0;
int sum_count_static = 0; int sum_count_static = 0;
llama_token max_token = -1;
for (std::pair<llama_token, int> token_count_static : part_static) {
const int32_t count_static = token_count_static.second;
sum_count_static += count_static;
}
for (std::pair<llama_token, int> token_count_static : part_static) { for (std::pair<llama_token, int> token_count_static : part_static) {
const llama_token token = token_count_static.first; const llama_token token = token_count_static.first;
const int32_t count_static = token_count_static.second; const int32_t count_static = token_count_static.second;
if (count_static > max_count_static) { if (sum_count_static < min_sample_size[LLAMA_NGRAM_STATIC-1]) {
max_token = token; continue;
max_count_static = count_static; }
if (100*count_static < min_percent[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
continue;;
} }
sum_count_static += count_static;
}
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) { draft_candidate cc;
return -1; for (const llama_token & t : cp.draft) {
cc.draft.push_back(t);
}
cc.draft.push_back(token);
cc.nll = cp.nll - logf(1.0f*count_static/sum_count_static);
cc.nsampled = nsc;
bool duplicate = false;
for (const draft_candidate & co : drafts_new) {
if (co.draft == cc.draft) {
duplicate = true;
break;
}
}
if (duplicate) {
continue;
}
drafts_new.push_back(cc);
} }
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
return -1;
}
return max_token;
} }
// Try to draft a token from primary cache (context/dynamic), validate with static cache: // Try to draft tokens from primary cache (context/dynamic), validate with static cache:
static llama_token try_draft( static void try_draft(
llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static, llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
const int * min_sample_size, const int * min_percent) { const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
const int ngram_min, std::vector<draft_candidate> & drafts_new) {
llama_token drafted_token = -1; for (int i = ngrams_primary.size()-1; i >= 0; --i) {
const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
if (nsc < (ngram_min + i + 1)/2) {
break;
}
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
const llama_ngram ngram_primary = ngrams_primary[i]; const llama_ngram ngram_primary = ngrams_primary[i];
llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary); llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
@ -106,10 +155,8 @@ static llama_token try_draft(
} }
const llama_ngram_cache_part part_primary = part_primary_it->second; const llama_ngram_cache_part part_primary = part_primary_it->second;
int max_count_primary = 0;
int max_count_static = 0;
int sum_count_primary = 0; int sum_count_primary = 0;
llama_token max_token = -1; int sum_count_prod = 0;
for (std::pair<llama_token, int> token_count_primary : part_primary) { for (std::pair<llama_token, int> token_count_primary : part_primary) {
const llama_token token = token_count_primary.first; const llama_token token = token_count_primary.first;
@ -119,44 +166,100 @@ static llama_token try_draft(
const int32_t count_primary = token_count_primary.second; const int32_t count_primary = token_count_primary.second;
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1; const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
if (count_primary*count_static > max_count_primary*max_count_static) {
max_token = token;
max_count_primary = count_primary;
max_count_static = count_static;
}
sum_count_primary += count_primary; sum_count_primary += count_primary;
sum_count_prod += count_primary*count_static;
} }
if (sum_count_primary < min_sample_size[i]) { for (std::pair<llama_token, int> token_count_primary : part_primary) {
continue; const llama_token token = token_count_primary.first;
llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
const int32_t count_primary = token_count_primary.second;
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
const int32_t count_prod = count_primary*count_static;
if (sum_count_primary < min_sample_size[i]) {
continue;
}
if (100*count_prod < min_percent[i]*sum_count_prod) {
continue;
}
draft_candidate cc;
for (const llama_token & t : cp.draft) {
cc.draft.push_back(t);
}
cc.draft.push_back(token);
cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod);
cc.nsampled = nsc;
bool duplicate = false;
for (const draft_candidate & co : drafts_new) {
if (co.draft == cc.draft) {
duplicate = true;
break;
}
}
if (duplicate) {
continue;
}
drafts_new.push_back(cc);
} }
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
continue;;
}
drafted_token = max_token;
} }
return drafted_token;
} }
void llama_ngram_cache_draft( void llama_ngram_cache_draft(
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max, std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
) { ) {
GGML_ASSERT(draft.size() == 1); if (n_draft == 0) {
const int inp_size = inp.size();
if (inp_size < LLAMA_NGRAM_STATIC) {
return; return;
} }
while ((int) draft.size()-1 < n_draft) { GGML_ASSERT(drafts.size() == 1);
llama_token drafted_token = -1; GGML_ASSERT(drafts[0].size() == 1);
const int inp_size = inp.size();
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1; if (inp_size < std::max(ngram_max, LLAMA_NGRAM_STATIC)) {
return;
}
// While building the tree, store drafts with potential children in a heap:
std::vector<draft_candidate> drafts_wip;
{
draft_candidate candidate;
candidate.draft.push_back(drafts[0][0]);
candidate.nll = 0.0f;
candidate.nsampled = LLAMA_NGRAM_MAX;
drafts_wip.push_back(candidate);
}
drafts.clear();
int i_draft = 0;
// Temporarily hold new drafts in vector, only add part of them in the last iteration to exactly meet n_draft.
std::vector<draft_candidate> drafts_new;
while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
for (const draft_candidate & ndc : drafts_new) {
drafts_wip.push_back(ndc);
std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
i_draft++;
}
drafts_new.clear();
std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
drafts_wip.pop_back();
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + cp.draft.size()-1;
llama_ngram ngram_static; llama_ngram ngram_static;
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) { for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j); ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
} }
llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
llama_ngram_cache_part part_static; llama_ngram_cache_part part_static;
@ -167,29 +270,37 @@ void llama_ngram_cache_draft(
// cd = context + dynamic // cd = context + dynamic
std::vector<llama_ngram> ngrams_cd; std::vector<llama_ngram> ngrams_cd;
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) { for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1; const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
llama_ngram ngram_cd; llama_ngram ngram_cd;
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) { for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j); ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
} }
ngrams_cd.push_back(ngram_cd); ngrams_cd.push_back(ngram_cd);
} }
if (drafted_token == -1) {
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
}
if (drafted_token == -1) {
drafted_token = try_draft(nc_static, ngram_static);
}
if (drafted_token == -1) { try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax, cp, ngram_min, drafts_new);
try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax, cp, ngram_min, drafts_new);
try_draft(nc_static, ngram_static, draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
if (drafts_new.empty()) {
drafts.push_back(cp.draft);
i_draft++;
}
}
for (const draft_candidate & dc : drafts_wip) { // dc = draft child
drafts.push_back(dc.draft);
}
std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());
for (const draft_candidate & dc : drafts_new) {
drafts.push_back(dc.draft);
i_draft++;
if (i_draft >= n_draft) {
break; break;
} }
LOG(" - draft candidate: token=%d\n", drafted_token);
draft.push_back(drafted_token);
} }
} }

View file

@ -60,6 +60,7 @@ typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
// n-gram -> empirical distribution of following tokens // n-gram -> empirical distribution of following tokens
typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache; typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;
typedef std::vector<llama_token> llama_draft_t;
// Update an ngram cache with tokens. // Update an ngram cache with tokens.
// ngram_cache: the cache to modify. // ngram_cache: the cache to modify.
@ -82,7 +83,7 @@ void llama_ngram_cache_update(
// nc_dynamic: ngram cache based on previous user generations. // nc_dynamic: ngram cache based on previous user generations.
// nc_static: ngram cache generated from a large text corpus, used for validation. // nc_static: ngram cache generated from a large text corpus, used for validation.
void llama_ngram_cache_draft( void llama_ngram_cache_draft(
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max, std::vector<llama_token> & inp, std::vector<llama_draft_t> & drafts, int n_draft, int ngram_min, int ngram_max,
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static); llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
// Save an ngram cache to a file. // Save an ngram cache to a file.

View file

@ -8,6 +8,7 @@
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
@ -80,22 +81,42 @@ int main(int argc, char ** argv){
while ((int) pseudo_output.size() < n_ctx) { while ((int) pseudo_output.size() < n_ctx) {
// Simulate drafting and decoding from draft: // Simulate drafting and decoding from draft:
std::vector<llama_token> draft; std::vector<llama_draft_t> drafts;
draft.push_back(pseudo_output.back()); llama_draft_t draft0;
draft0.push_back(pseudo_output.back());
drafts.push_back(draft0);
{ {
const int64_t t_start_draft_us = ggml_time_us(); const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); llama_ngram_cache_draft(
pseudo_output, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
t_draft_us += ggml_time_us() - t_start_draft_us; t_draft_us += ggml_time_us() - t_start_draft_us;
} }
GGML_ASSERT((int) drafts.size() <= n_draft || n_draft <= 0);
n_drafted += draft.size() - 1; // FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
for (int j = 1; j < n_draft + 1; ++j) {
std::set<llama_token> seen_tokens;
for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) { for (const llama_draft_t & draft : drafts) {
if (j < (int) draft.size() && seen_tokens.find(draft[j]) == seen_tokens.end()) {
seen_tokens.emplace(draft[j]);
n_drafted++;
}
}
}
for (int j = 1; j < n_draft + 1 && (int) pseudo_output.size() < n_ctx; ++j) {
const llama_token ground_truth = inp_slice[pseudo_output.size()]; const llama_token ground_truth = inp_slice[pseudo_output.size()];
const llama_token drafted = draft[j];
if (ground_truth != drafted) { bool ground_truth_in_drafts = false;
for (const llama_draft_t & draft : drafts) {
if (j < (int) draft.size() && draft[j] == ground_truth) {
ground_truth_in_drafts = true;
break;
}
}
if (!ground_truth_in_drafts) {
break; break;
} }
@ -119,7 +140,7 @@ int main(int argc, char ** argv){
} }
} }
draft.erase(draft.begin()); drafts.clear();
} }
if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) { if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {

View file

@ -7,6 +7,7 @@
#include <cstdint> #include <cstdint>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
@ -21,6 +22,7 @@ int main(int argc, char ** argv){
// max. number of additional tokens to draft if match is found // max. number of additional tokens to draft if match is found
const int n_draft = params.n_draft; const int n_draft = params.n_draft;
const int n_seq = std::max(n_draft, 1);
const bool dump_kv_cache = params.dump_kv_cache; const bool dump_kv_cache = params.dump_kv_cache;
@ -108,9 +110,12 @@ int main(int argc, char ** argv){
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
std::vector<llama_token> draft; std::vector<llama_draft_t> drafts;
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); llama_batch batch_tgt = llama_batch_init(max_context_size, 0, n_seq);
std::vector<std::vector<int>> sampling_idx_store;
sampling_idx_store.resize(n_seq);
sampling_idx_store[0].push_back(0);
// debug // debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@ -124,13 +129,11 @@ int main(int argc, char ** argv){
llama_kv_cache_dump_view_seqs(kvc_view, 40); llama_kv_cache_dump_view_seqs(kvc_view, 40);
} }
// print current draft sequence
LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
int i_dft = 0; int i_dft = 0;
int seq_best = 0;
while (true) { while (true) {
// sample from the target model // sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, sampling_idx_store[seq_best][i_dft]);
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);
@ -147,24 +150,32 @@ int main(int argc, char ** argv){
++n_predict; ++n_predict;
// check if the target token matches the draft // check if the target token matches the draft
if (i_dft < (int) draft.size() && id == draft[i_dft]) { bool accepted = false;
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); for (int j = 0; j < (int) drafts.size() && !has_eos && !drafts.empty(); ++j) {
++n_accept; if (i_dft + 1 < (int) drafts[j].size() && id == drafts[j][i_dft + 1]) {
++n_past; LOG("draft success: (%d, '%s'), seq_id=%d\n", id, token_str.c_str(), j);
++i_dft; ++n_accept;
inp.push_back(id); ++n_past;
{ ++i_dft;
// Update context ngram cache with the newly accepted token: inp.push_back(id);
const int64_t t_start_draft_us = ggml_time_us(); seq_best = j;
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); {
t_draft_us += ggml_time_us() - t_start_draft_us; // Update context ngram cache with the newly accepted token:
} const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
if (params.use_color) { if (params.use_color) {
// color accepted draft token // color accepted draft token
printf("\033[34m%s\033[0m", token_str.c_str()); printf("\033[34m%s\033[0m", token_str.c_str());
fflush(stdout); fflush(stdout);
}
accepted = true;
break;
} }
}
if (accepted) {
continue; continue;
} }
@ -174,10 +185,10 @@ int main(int argc, char ** argv){
fflush(stdout); fflush(stdout);
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); LOG("sampled: (%d, '%s')\n", id, token_str.c_str());
draft.clear(); drafts.clear();
draft.push_back(id); drafts.push_back({id});
inp.push_back(id); inp.push_back(id);
{ {
// Update context ngram cache with the newly accepted token: // Update context ngram cache with the newly accepted token:
@ -194,29 +205,87 @@ int main(int argc, char ** argv){
// KV cache management // KV cache management
// clean the cache of draft tokens that weren't accepted // clean the cache of draft tokens that weren't accepted
if (seq_best != 0 && i_dft > 0) {
llama_kv_cache_seq_cp(ctx, seq_best, 0, n_past-i_dft, n_past);
}
llama_kv_cache_seq_keep(ctx, 0);
llama_kv_cache_seq_rm(ctx, 0, n_past, -1); llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
llama_batch_clear(batch_tgt); llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); for (int j = 0; j < n_seq; ++j) {
sampling_idx_store[j].clear();
// Draft already contains a single token sampled from the model:
GGML_ASSERT(draft.size() == 1);
GGML_ASSERT(draft[0] == inp.back());
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
for (size_t i = 1; i < draft.size(); ++i) {
llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
} }
// Draft already contains a single token sampled from the model:
GGML_ASSERT(drafts.size() == 1);
GGML_ASSERT(drafts[0].size() == 1);
GGML_ASSERT(drafts[0][0] == inp.back());
const int64_t t_start_draft_us = ggml_time_us();
llama_ngram_cache_draft(inp, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
for (int j = 1; j < (int) drafts.size(); ++j) {
llama_kv_cache_seq_cp(ctx, 0, j, -1, -1);
}
int draft_max = 0;
for (const llama_draft_t & draft : drafts) {
draft_max = std::max(draft_max, (int) draft.size());
}
if (draft_max > 1) {
LOG("drafts:\n");
for (const llama_draft_t & draft : drafts) {
LOG(" - %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
}
}
// FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
for (int i = 0; i < draft_max; ++i) {
std::set<llama_token> seen_tokens;
while (true) {
llama_token current_token = -1;
std::vector<llama_seq_id> current_seq_ids;
for (int j = 0; j < (int) drafts.size(); ++j) {
if (i >= (int) drafts[j].size()) {
continue;
}
if (current_token == -1) {
if (seen_tokens.find(drafts[j][i]) != seen_tokens.end()) {
continue;
}
current_token = drafts[j][i];
seen_tokens.emplace(current_token);
}
if (drafts[j][i] != current_token) {
continue;
}
current_seq_ids.push_back(j);
}
if (current_token == -1) {
break;
}
for (const llama_seq_id & sid : current_seq_ids) {
sampling_idx_store[sid].push_back(batch_tgt.n_tokens);
}
llama_batch_add(batch_tgt, current_token, n_past + i, current_seq_ids, true);
n_drafted++;
}
}
n_drafted--; // 1 out of the added token was sampled;
t_draft_us += ggml_time_us() - t_start_draft_us; t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt); llama_decode(ctx, batch_tgt);
++n_past; ++n_past;
draft.erase(draft.begin());
} }
auto t_dec_end = ggml_time_us(); auto t_dec_end = ggml_time_us();