Merge 8c9784c65d into 0ab30f8d82
commit 408c8402b6
4 changed files with 315 additions and 113 deletions
@@ -52,52 +52,101 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
     return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
 }
 
-// If sample size or percentage are below these thresholds the draft is aborted early:
-constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
-constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {66, 50, 50, 50};
+// Sample size and percentage must meet these thresholds to be added to the draft tree:
+constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 1, 1, 1, 1};
+constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX]     = {20, 20, 10, 10};
 constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
-constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {75, 66, 66, 66};
+constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
+struct draft_candidate {
+    llama_draft_t draft;
+    float nll;
+    int nsampled;
+};
+
+struct compare_draft_candidate {
+    bool operator()(const draft_candidate & a, const draft_candidate & b){
+        if (a.nsampled > b.nsampled) {
+            return true;
+        }
+        if (a.nsampled < b.nsampled) {
+            return false;
+        }
+        return a.nll < b.nll;
+    }
+};
+
+// Helper function that tries to draft tokens from only the static ngram cache:
+static void try_draft(
+    llama_ngram_cache & nc_static, const llama_ngram & ngram_static,
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
+
+    const int nsc = (ngram_min + LLAMA_NGRAM_STATIC) - (cp.draft.size() - 1);
+    if (nsc < (ngram_min + LLAMA_NGRAM_STATIC + 1)/2) {
+        return;
+    }
+
-// Helper function that tries to draft a token from only the static ngram cache:
-static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) {
     llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     if (part_static_it == nc_static.end()) {
-        return -1;
+        return;
     }
     const llama_ngram_cache_part part_static = part_static_it->second;
 
-    int max_count_static = 0;
     int sum_count_static = 0;
-    llama_token max_token = -1;
+
+    for (std::pair<llama_token, int> token_count_static : part_static) {
+        const int32_t count_static = token_count_static.second;
+
+        sum_count_static += count_static;
+    }
 
     for (std::pair<llama_token, int> token_count_static : part_static) {
         const llama_token token = token_count_static.first;
         const int32_t count_static = token_count_static.second;
 
-        if (count_static > max_count_static) {
-            max_token = token;
-            max_count_static = count_static;
+        if (sum_count_static < min_sample_size[LLAMA_NGRAM_STATIC-1]) {
+            continue;
         }
-        sum_count_static += count_static;
-    }
 
-    if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
-        return -1;
+        if (100*count_static < min_percent[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
+            continue;
+        }
+
+        draft_candidate cc;
+        for (const llama_token & t : cp.draft) {
+            cc.draft.push_back(t);
+        }
+        cc.draft.push_back(token);
+        cc.nll = cp.nll - logf(1.0f*count_static/sum_count_static);
+        cc.nsampled = nsc;
+
+        bool duplicate = false;
+        for (const draft_candidate & co : drafts_new) {
+            if (co.draft == cc.draft) {
+                duplicate = true;
+                break;
+            }
+        }
+        if (duplicate) {
+            continue;
+        }
+
+        drafts_new.push_back(cc);
     }
-    if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
-        return -1;
-    }
-    return max_token;
 }
 
-// Try to draft a token from primary cache (context/dynamic), validate with static cache:
-static llama_token try_draft(
+// Try to draft tokens from primary cache (context/dynamic), validate with static cache:
+static void try_draft(
     llama_ngram_cache & nc_primary, const std::vector<llama_ngram> & ngrams_primary, llama_ngram_cache_part & part_static,
-    const int * min_sample_size, const int * min_percent) {
+    const int * min_sample_size, const int * min_percent, const draft_candidate & cp,
+    const int ngram_min, std::vector<draft_candidate> & drafts_new) {
 
-    llama_token drafted_token = -1;
-
-    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
+    for (int i = ngrams_primary.size()-1; i >= 0; --i) {
+        const int nsc = (ngram_min + i) - (cp.draft.size() - 1);
+        if (nsc < (ngram_min + i + 1)/2) {
+            break;
+        }
+
         const llama_ngram ngram_primary = ngrams_primary[i];
 
         llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
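Note: `compare_draft_candidate` is a strict-weak ordering in which a candidate compares as "smaller" when it has more cache-backed tokens (`nsampled`), with ties broken by lower `nll`. Used with `std::push_heap`/`std::pop_heap` in `llama_ngram_cache_draft` below, `pop_heap` therefore surfaces the candidate with the fewest sampled tokens (highest `nll` on ties), while `std::sort` lists the best-supported, most probable drafts first. A standalone toy sketch of the pop behavior, with stand-in typedefs, not part of this diff:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

typedef int llama_token;                        // stand-in for the real typedef
typedef std::vector<llama_token> llama_draft_t;

struct draft_candidate {
    llama_draft_t draft;
    float nll;      // negative log-likelihood accumulated along the path
    int   nsampled; // number of tokens backed by cache samples
};

struct compare_draft_candidate {
    bool operator()(const draft_candidate & a, const draft_candidate & b){
        if (a.nsampled > b.nsampled) { return true;  }
        if (a.nsampled < b.nsampled) { return false; }
        return a.nll < b.nll;
    }
};

int main() {
    std::vector<draft_candidate> heap = {
        {{1}, 0.2f, 3}, // well-supported, likely
        {{2}, 0.9f, 1}, // poorly supported, unlikely
        {{3}, 0.1f, 1}, // poorly supported, likely
    };
    std::make_heap(heap.begin(), heap.end(), compare_draft_candidate());
    std::pop_heap (heap.begin(), heap.end(), compare_draft_candidate());
    // pop_heap moves the "largest" element under the comparator to the back,
    // i.e. the candidate with the fewest sampled tokens, highest nll on ties: {2}.
    printf("popped token=%d nll=%.1f nsampled=%d\n",
           heap.back().draft[0], heap.back().nll, heap.back().nsampled);
    return 0;
}
```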
@@ -106,10 +155,8 @@ static llama_token try_draft(
         }
         const llama_ngram_cache_part part_primary = part_primary_it->second;
 
-        int max_count_primary = 0;
-        int max_count_static = 0;
         int sum_count_primary = 0;
-        llama_token max_token = -1;
+        int sum_count_prod = 0;
 
         for (std::pair<llama_token, int> token_count_primary : part_primary) {
             const llama_token token = token_count_primary.first;
@@ -119,44 +166,100 @@ static llama_token try_draft(
             const int32_t count_primary = token_count_primary.second;
             const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
 
-            if (count_primary*count_static > max_count_primary*max_count_static) {
-                max_token = token;
-                max_count_primary = count_primary;
-                max_count_static = count_static;
-            }
             sum_count_primary += count_primary;
+            sum_count_prod += count_primary*count_static;
         }
 
-        if (sum_count_primary < min_sample_size[i]) {
-            continue;
-        }
-        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
-            continue;;
-        }
-        drafted_token = max_token;
+        for (std::pair<llama_token, int> token_count_primary : part_primary) {
+            const llama_token token = token_count_primary.first;
+
+            llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
+
+            const int32_t count_primary = token_count_primary.second;
+            const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
+            const int32_t count_prod = count_primary*count_static;
+
+            if (sum_count_primary < min_sample_size[i]) {
+                continue;
+            }
+
+            if (100*count_prod < min_percent[i]*sum_count_prod) {
+                continue;
+            }
+
+            draft_candidate cc;
+            for (const llama_token & t : cp.draft) {
+                cc.draft.push_back(t);
+            }
+            cc.draft.push_back(token);
+            cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod);
+            cc.nsampled = nsc;
+
+            bool duplicate = false;
+            for (const draft_candidate & co : drafts_new) {
+                if (co.draft == cc.draft) {
+                    duplicate = true;
+                    break;
+                }
+            }
+            if (duplicate) {
+                continue;
+            }
+
+            drafts_new.push_back(cc);
+        }
     }
-
-    return drafted_token;
 }
 
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<std::vector<llama_token>> & drafts, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
 ) {
-    GGML_ASSERT(draft.size() == 1);
-    const int inp_size = inp.size();
-
-    if (inp_size < LLAMA_NGRAM_STATIC) {
+    if (n_draft == 0) {
         return;
     }
 
-    while ((int) draft.size()-1 < n_draft) {
-        llama_token drafted_token = -1;
+    GGML_ASSERT(drafts.size() == 1);
+    GGML_ASSERT(drafts[0].size() == 1);
+    const int inp_size = inp.size();
 
-        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
+    if (inp_size < std::max(ngram_max, LLAMA_NGRAM_STATIC)) {
+        return;
+    }
+
+    // While building the tree, store drafts with potential children in a heap:
+    std::vector<draft_candidate> drafts_wip;
+
+    {
+        draft_candidate candidate;
+        candidate.draft.push_back(drafts[0][0]);
+        candidate.nll = 0.0f;
+        candidate.nsampled = LLAMA_NGRAM_MAX;
+        drafts_wip.push_back(candidate);
+    }
+
+    drafts.clear();
+    int i_draft = 0;
+
+    // Temporarily hold new drafts in a vector, only add part of them in the last iteration to exactly meet n_draft.
+    std::vector<draft_candidate> drafts_new;
+
+    while (i_draft + ((int) drafts_new.size()) < n_draft && !(drafts_wip.empty() && drafts_new.empty())) {
+        for (const draft_candidate & ndc : drafts_new) {
+            drafts_wip.push_back(ndc);
+            std::push_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+            i_draft++;
+        }
+        drafts_new.clear();
+
+        std::pop_heap(drafts_wip.begin(), drafts_wip.end(), compare_draft_candidate());
+        const draft_candidate cp = drafts_wip.back(); // cp = candidate parent
+        drafts_wip.pop_back();
+
+        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + cp.draft.size()-1;
         llama_ngram ngram_static;
         for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+            ngram_static.tokens[j-ngram_start_static] = get_token(inp, cp.draft, j);
         }
         llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
         llama_ngram_cache_part part_static;
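Note: the `cc.nll = cp.nll - logf(1.0f*count_prod/sum_count_prod)` update treats the cache counts as an empirical next-token distribution: extending a parent by a token with empirical probability p = count/sum adds -log(p) to the path's negative log-likelihood. A toy illustration of that bookkeeping, not part of this diff:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Root candidate starts at nll = 0. Appending a token observed `count`
    // times out of `sum` total observations adds -log(count/sum).
    float nll = 0.0f;
    const int counts[2] = {8, 2};
    const int sums[2]   = {10, 4};

    for (int i = 0; i < 2; ++i) {
        nll -= logf(1.0f*counts[i]/sums[i]);
    }
    // -(log 0.8 + log 0.5) ~= 0.916; a lower nll means a more probable path.
    printf("path nll: %.3f\n", nll);
    return 0;
}
```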
@@ -167,29 +270,37 @@ void llama_ngram_cache_draft(
         // cd = context + dynamic
         std::vector<llama_ngram> ngrams_cd;
         for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
-            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
+            const int ngram_start_cd = inp_size-ngram_size_cd + cp.draft.size()-1;
             llama_ngram ngram_cd;
             for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
-                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, cp.draft, j);
             }
             ngrams_cd.push_back(ngram_cd);
         }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
-        }
-        if (drafted_token == -1) {
-            drafted_token = try_draft(nc_static, ngram_static);
-        }
-
-        if (drafted_token == -1) {
-            break;
-        }
-
-        LOG(" - draft candidate: token=%d\n", drafted_token);
-        draft.push_back(drafted_token);
+
+        try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax,    draft_min_percent_lax,    cp, ngram_min, drafts_new);
+        try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_lax,    cp, ngram_min, drafts_new);
+        try_draft(nc_static,  ngram_static,           draft_min_sample_size_strict, draft_min_percent_strict, cp, ngram_min, drafts_new);
+
+        if (drafts_new.empty()) {
+            drafts.push_back(cp.draft);
+            i_draft++;
+        }
+    }
+
+    for (const draft_candidate & dc : drafts_wip) { // dc = draft child
+        drafts.push_back(dc.draft);
+    }
+
+    std::sort(drafts_new.begin(), drafts_new.end(), compare_draft_candidate());
+
+    for (const draft_candidate & dc : drafts_new) {
+        drafts.push_back(dc.draft);
+        i_draft++;
+
+        if (i_draft >= n_draft) {
+            break;
+        }
     }
 }
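Note: the loop above is best-first tree construction: newly created children go into the `drafts_wip` heap, the most promising parent is popped and extended via the three `try_draft` calls, parents that yield no children are emitted as finished drafts, and leftover work-in-progress drafts are flushed at the end, all within the `n_draft` budget. A heavily simplified standalone sketch of that control flow; the two fixed "continuations" stand in for what the ngram caches would propose, so this is not the PR's logic verbatim:

```cpp
#include <cstdio>
#include <initializer_list>
#include <vector>

int main() {
    const int n_draft = 4;
    std::vector<std::vector<int>> finished;      // drafts handed back to the caller
    std::vector<std::vector<int>> wip = {{100}}; // root: the already sampled token

    int budget = 0;
    while (!wip.empty() && budget < n_draft) {
        std::vector<int> parent = wip.back();    // stand-in for the heap pop
        wip.pop_back();

        // Pretend the caches proposed two continuations for this parent:
        for (int t : {1, 2}) {
            std::vector<int> child = parent;
            child.push_back(parent.back() + t);
            (child.size() < 3 ? wip : finished).push_back(child);
            budget++;
        }
    }
    // Like the real loop, leftover work-in-progress drafts are also returned:
    for (const auto & d : wip) {
        finished.push_back(d);
    }
    for (const auto & d : finished) {
        printf("draft of length %zu\n", d.size());
    }
    return 0;
}
```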
@@ -60,6 +60,7 @@ typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part;
 // n-gram -> empirical distribution of following tokens
 typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash_function> llama_ngram_cache;
 
+typedef std::vector<llama_token> llama_draft_t;
+
 // Update an ngram cache with tokens.
 // ngram_cache: the cache to modify.
@@ -82,7 +83,7 @@ void llama_ngram_cache_update(
 // nc_dynamic: ngram cache based on previous user generations.
 // nc_static: ngram cache generated from a large text corpus, used for validation.
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    std::vector<llama_token> & inp, std::vector<llama_draft_t> & drafts, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
 
 // Save an ngram cache to a file.
@@ -8,6 +8,7 @@
 #include <cstdint>
 #include <cstdio>
 #include <fstream>
+#include <set>
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -80,22 +81,42 @@ int main(int argc, char ** argv){
 
     while ((int) pseudo_output.size() < n_ctx) {
         // Simulate drafting and decoding from draft:
-        std::vector<llama_token> draft;
-        draft.push_back(pseudo_output.back());
+        std::vector<llama_draft_t> drafts;
+        llama_draft_t draft0;
+        draft0.push_back(pseudo_output.back());
+        drafts.push_back(draft0);
 
         {
             const int64_t t_start_draft_us = ggml_time_us();
-            llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+            llama_ngram_cache_draft(
+                pseudo_output, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
             t_draft_us += ggml_time_us() - t_start_draft_us;
         }
+        GGML_ASSERT((int) drafts.size() <= n_draft || n_draft <= 0);
 
-        n_drafted += draft.size() - 1;
+        // FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
+        for (int j = 1; j < n_draft + 1; ++j) {
+            std::set<llama_token> seen_tokens;
 
-        for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
+            for (const llama_draft_t & draft : drafts) {
+                if (j < (int) draft.size() && seen_tokens.find(draft[j]) == seen_tokens.end()) {
+                    seen_tokens.emplace(draft[j]);
+                    n_drafted++;
+                }
+            }
+        }
+
+        for (int j = 1; j < n_draft + 1 && (int) pseudo_output.size() < n_ctx; ++j) {
             const llama_token ground_truth = inp_slice[pseudo_output.size()];
-            const llama_token drafted = draft[j];
 
-            if (ground_truth != drafted) {
+            bool ground_truth_in_drafts = false;
+            for (const llama_draft_t & draft : drafts) {
+                if (j < (int) draft.size() && draft[j] == ground_truth) {
+                    ground_truth_in_drafts = true;
+                    break;
+                }
+            }
+            if (!ground_truth_in_drafts) {
                 break;
             }
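Note: because several draft sequences can propose the same token at the same position, `n_drafted` now counts unique tokens per position, matching what would actually occupy batch rows. A standalone sketch of that dedup with toy data, not from this diff:

```cpp
#include <cstdio>
#include <set>
#include <vector>

int main() {
    // Three draft sequences; position 0 is the sampled root, not drafted.
    std::vector<std::vector<int>> drafts = {{7, 3, 5}, {7, 3, 9}, {7, 4}};

    int n_drafted = 0;
    for (size_t j = 1; j < 3; ++j) {
        std::set<int> seen_tokens;
        for (const auto & draft : drafts) {
            if (j < draft.size() && seen_tokens.find(draft[j]) == seen_tokens.end()) {
                seen_tokens.emplace(draft[j]);
                n_drafted++;
            }
        }
    }
    printf("unique drafted tokens: %d\n", n_drafted); // {3,4} at j=1, {5,9} at j=2 -> 4
    return 0;
}
```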
@@ -119,7 +140,7 @@ int main(int argc, char ** argv){
             }
         }
 
-        draft.erase(draft.begin());
+        drafts.clear();
 
     }
     if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
@@ -7,6 +7,7 @@
 #include <cstdint>
 #include <cstdio>
 #include <fstream>
+#include <set>
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -21,6 +22,7 @@ int main(int argc, char ** argv){
 
     // max. number of additional tokens to draft if match is found
     const int n_draft = params.n_draft;
+    const int n_seq = std::max(n_draft, 1);
 
     const bool dump_kv_cache = params.dump_kv_cache;
@@ -108,9 +110,12 @@ int main(int argc, char ** argv){
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
 
-    std::vector<llama_token> draft;
+    std::vector<llama_draft_t> drafts;
 
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
+    llama_batch batch_tgt = llama_batch_init(max_context_size, 0, n_seq);
+    std::vector<std::vector<int>> sampling_idx_store;
+    sampling_idx_store.resize(n_seq);
+    sampling_idx_store[0].push_back(0);
 
     // debug
     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@@ -124,13 +129,11 @@ int main(int argc, char ** argv){
             llama_kv_cache_dump_view_seqs(kvc_view, 40);
         }
 
-        // print current draft sequence
-        LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
-
         int i_dft = 0;
+        int seq_best = 0;
         while (true) {
             // sample from the target model
-            llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+            llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, sampling_idx_store[seq_best][i_dft]);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
@@ -147,24 +150,32 @@ int main(int argc, char ** argv){
             ++n_predict;
 
             // check if the target token matches the draft
-            if (i_dft < (int) draft.size() && id == draft[i_dft]) {
-                LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
-                ++n_accept;
-                ++n_past;
-                ++i_dft;
-                inp.push_back(id);
-                {
-                    // Update context ngram cache with the newly accepted token:
-                    const int64_t t_start_draft_us = ggml_time_us();
-                    llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
-                    t_draft_us += ggml_time_us() - t_start_draft_us;
-                }
+            bool accepted = false;
+            for (int j = 0; j < (int) drafts.size() && !has_eos && !drafts.empty(); ++j) {
+                if (i_dft + 1 < (int) drafts[j].size() && id == drafts[j][i_dft + 1]) {
+                    LOG("draft success: (%d, '%s'), seq_id=%d\n", id, token_str.c_str(), j);
+                    ++n_accept;
+                    ++n_past;
+                    ++i_dft;
+                    inp.push_back(id);
+                    seq_best = j;
+                    {
+                        // Update context ngram cache with the newly accepted token:
+                        const int64_t t_start_draft_us = ggml_time_us();
+                        llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
+                        t_draft_us += ggml_time_us() - t_start_draft_us;
+                    }
 
-                if (params.use_color) {
-                    // color accepted draft token
-                    printf("\033[34m%s\033[0m", token_str.c_str());
-                    fflush(stdout);
+                    if (params.use_color) {
+                        // color accepted draft token
+                        printf("\033[34m%s\033[0m", token_str.c_str());
+                        fflush(stdout);
+                    }
+                    accepted = true;
+                    break;
                 }
+            }
+            if (accepted) {
                 continue;
             }
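Note: acceptance now scans all draft sequences: the sampled token is compared against position `i_dft + 1` of each draft (position 0 holds the previously sampled root), and the first match sets `seq_best` so later sampling indices come from that sequence. A toy version of just that check, not from this diff:

```cpp
#include <cstdio>
#include <vector>

int main() {
    std::vector<std::vector<int>> drafts = {{7, 3, 5}, {7, 4, 9}};
    int i_dft = 0; // tokens accepted so far in the current round
    int id = 4;    // token just sampled from the target model

    int seq_best = -1;
    for (int j = 0; j < (int) drafts.size(); ++j) {
        if (i_dft + 1 < (int) drafts[j].size() && id == drafts[j][i_dft + 1]) {
            seq_best = j; // continue decoding along this draft sequence
            break;
        }
    }
    printf("accepted by sequence: %d\n", seq_best); // sequence 1 drafted 4
    return 0;
}
```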
@@ -174,10 +185,10 @@ int main(int argc, char ** argv){
             fflush(stdout);
 
-            LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+            LOG("sampled: (%d, '%s')\n", id, token_str.c_str());
 
-            draft.clear();
-            draft.push_back(id);
+            drafts.clear();
+            drafts.push_back({id});
             inp.push_back(id);
             {
                 // Update context ngram cache with the newly accepted token:
@@ -194,29 +205,87 @@ int main(int argc, char ** argv){
 
         // KV cache management
         // clean the cache of draft tokens that weren't accepted
+        if (seq_best != 0 && i_dft > 0) {
+            llama_kv_cache_seq_cp(ctx, seq_best, 0, n_past-i_dft, n_past);
+        }
+        llama_kv_cache_seq_keep(ctx, 0);
         llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 
         llama_batch_clear(batch_tgt);
-        llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
+        for (int j = 0; j < n_seq; ++j) {
+            sampling_idx_store[j].clear();
+        }
 
-        // Draft already contains a single token sampled from the model:
-        GGML_ASSERT(draft.size() == 1);
-        GGML_ASSERT(draft[0] == inp.back());
-        const int64_t t_start_draft_us = ggml_time_us();
-
-        llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
-
-        for (size_t i = 1; i < draft.size(); ++i) {
-            llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
-        }
+        // Draft already contains a single token sampled from the model:
+        GGML_ASSERT(drafts.size() == 1);
+        GGML_ASSERT(drafts[0].size() == 1);
+        GGML_ASSERT(drafts[0][0] == inp.back());
+        const int64_t t_start_draft_us = ggml_time_us();
+
+        llama_ngram_cache_draft(inp, drafts, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+
+        for (int j = 1; j < (int) drafts.size(); ++j) {
+            llama_kv_cache_seq_cp(ctx, 0, j, -1, -1);
+        }
+
+        int draft_max = 0;
+        for (const llama_draft_t & draft : drafts) {
+            draft_max = std::max(draft_max, (int) draft.size());
+        }
+
+        if (draft_max > 1) {
+            LOG("drafts:\n");
+            for (const llama_draft_t & draft : drafts) {
+                LOG(" - %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
+            }
+        }
+
+        // FIXME wrong KV mask for converging sequences (does not seem to happen in practice).
+        for (int i = 0; i < draft_max; ++i) {
+            std::set<llama_token> seen_tokens;
+
+            while (true) {
+                llama_token current_token = -1;
+                std::vector<llama_seq_id> current_seq_ids;
+
+                for (int j = 0; j < (int) drafts.size(); ++j) {
+                    if (i >= (int) drafts[j].size()) {
+                        continue;
+                    }
+
+                    if (current_token == -1) {
+                        if (seen_tokens.find(drafts[j][i]) != seen_tokens.end()) {
+                            continue;
+                        }
+
+                        current_token = drafts[j][i];
+                        seen_tokens.emplace(current_token);
+                    }
+
+                    if (drafts[j][i] != current_token) {
+                        continue;
+                    }
+
+                    current_seq_ids.push_back(j);
+                }
+
+                if (current_token == -1) {
+                    break;
+                }
+
+                for (const llama_seq_id & sid : current_seq_ids) {
+                    sampling_idx_store[sid].push_back(batch_tgt.n_tokens);
+                }
+                llama_batch_add(batch_tgt, current_token, n_past + i, current_seq_ids, true);
+                n_drafted++;
+            }
+        }
+        n_drafted--; // 1 of the added tokens was sampled, not drafted
 
         t_draft_us += ggml_time_us() - t_start_draft_us;
-        n_drafted += draft.size() - 1;
 
         llama_decode(ctx, batch_tgt);
        ++n_past;
 
-        draft.erase(draft.begin());
     }
 
     auto t_dec_end = ggml_time_us();
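Note: the batch construction above dedups the draft tree: at each position, every unique token becomes a single batch row tagged with all sequence ids that share it, and `sampling_idx_store` records, per sequence, which row its logits will come from. A standalone sketch of that bookkeeping with plain vectors (no llama.cpp API, toy data, not part of this diff):

```cpp
#include <cstdio>
#include <set>
#include <vector>

int main() {
    std::vector<std::vector<int>> drafts = {{7, 3, 5}, {7, 3, 9}, {7, 4}};
    std::vector<std::vector<int>> sampling_idx_store(drafts.size());

    int n_rows = 0; // stands in for batch_tgt.n_tokens
    for (size_t i = 0; i < 3; ++i) {
        std::set<int> seen_tokens;
        for (size_t j = 0; j < drafts.size(); ++j) {
            if (i >= drafts[j].size() || seen_tokens.count(drafts[j][i])) {
                continue;
            }
            seen_tokens.insert(drafts[j][i]);

            // One row for this token; every sequence with the same token at
            // this position samples from the same row.
            for (size_t k = j; k < drafts.size(); ++k) {
                if (i < drafts[k].size() && drafts[k][i] == drafts[j][i]) {
                    sampling_idx_store[k].push_back(n_rows);
                }
            }
            n_rows++;
        }
    }
    printf("batch rows: %d (vs %d without sharing)\n", n_rows, 3 + 3 + 2);
    return 0;
}
```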