lookup: hashmap, most frequent tokens, abort early

This commit is contained in:
JohannesGaessler 2024-02-11 12:33:23 +01:00
parent 4a46d2b792
commit e3d8c5ecc9

View file

@ -7,6 +7,22 @@
#include <cstdio>
#include <string>
#include <vector>
#include <unordered_map>
// Data structures to map n-grams to empirical token probabilities:
typedef std::unordered_map<llama_token, int> token_hashmap; // token -> number of times token has been seen
typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
// This way no custom hashing function for the n-grams is needed.
// Min/max n-gram size to search for in prompt:
constexpr int ngram_min = 1;
constexpr int ngram_max = 4;
static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
// If sample size or percentage in context are below these thresholds the draft is aborted early:
constexpr float draft_min_sample_size[ngram_max] = { 2, 2, 1, 1};
constexpr float draft_min_percent[ngram_max] = {66, 50, 50, 50};
int main(int argc, char ** argv){
gpt_params params;
@ -16,9 +32,6 @@ int main(int argc, char ** argv){
}
// max/min n-grams size to search for in prompt
const int ngram_max = 4;
const int ngram_min = 1;
// length of the candidate / draft sequence, if match is found
const int n_draft = params.n_draft;
@ -38,6 +51,7 @@ int main(int argc, char ** argv){
// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
@ -46,6 +60,55 @@ int main(int argc, char ** argv){
std::vector<llama_token> inp;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
// atcs = all_token_counts: the hashmaps to modify.
// inp_data: the token sequence on which the hashmaps are based.
// inp_size: the current size of inp_data.
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
//
// In order to get correct results inp_data can ONLY BE APPENDED TO.
// Changes in the middle need a complete rebuild.
for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
all_token_hashmap * atc = atcs + ngram_size - ngram_min;
const int i_start = std::max(inp_size - nnew, ngram_size);
for (int i = i_start; i < inp_size; ++i) {
const int ngram_start = i - ngram_size;
uint64_t ngram = inp_data[ngram_start];
for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
const uint64_t ngram_part = inp_data[j];
ngram <<= 16;
ngram |= ngram_part;
}
const llama_token token = inp_data[i];
all_token_hashmap::iterator token_counts_it = atc->find(ngram);
if (token_counts_it == atc->end()) {
token_hashmap token_counts;
token_counts.emplace(token, 1);
atc->emplace(ngram, token_counts);
} else {
token_hashmap::iterator tc_it = token_counts_it->second.find(token);
if (tc_it == token_counts_it->second.end()) {
token_counts_it->second.emplace(token, 1);
} else {
tc_it->second++;
}
}
}
}
};
all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
int64_t t_draft_us = 0;
{
// Fill up hashmaps with tokens from user input:
const int64_t t_start_draft_us = ggml_time_us();
update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
t_draft_us += ggml_time_us() - t_start_draft_us;
}
const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4;
@ -75,8 +138,6 @@ int main(int argc, char ** argv){
int n_drafted = 0;
int n_accept = 0;
int64_t t_draft_us = 0;
int n_past = inp.size();
bool has_eos = false;
@ -128,6 +189,12 @@ int main(int argc, char ** argv){
++n_past;
++i_dft;
inp.push_back(id);
{
// Update hashmaps with the newly accepted token:
const int64_t t_start_draft_us = ggml_time_us();
update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
if (params.use_color) {
// color accepted draft token
@ -148,6 +215,12 @@ int main(int argc, char ** argv){
draft.clear();
draft.push_back(id);
inp.push_back(id);
{
// Update hashmaps with the newly accepted token:
const int64_t t_start_draft_us = ggml_time_us();
update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
t_draft_us += ggml_time_us() - t_start_draft_us;
}
break;
}
@ -162,44 +235,85 @@ int main(int argc, char ** argv){
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
// generate n_pred tokens through prompt lookup
auto prompt_lookup = [&]() -> void {
const int inp_size = inp.size();
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size];
for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
bool match = true;
for (int j = 0; j < ngram_size; ++j) {
if (inp[i + j] != ngram[j]) {
match = false;
break;
}
}
if (match) {
const int startIdx = i + ngram_size;
const int endIdx = startIdx + n_draft;
if (endIdx < inp_size) {
for (int j = startIdx; j < endIdx; ++j) {
LOG(" - draft candidate %d: %d\n", j, inp[j]);
draft.push_back(inp[j]);
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
++n_drafted;
}
return;
}
}
}
}
return;
auto get_token = [](const std::vector<llama_token> inp, const std::vector<llama_token> draft, const size_t i) -> llama_token {
// Helper function to get a token from the combined, speculative sequence of inp and draft.
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
};
auto prompt_lookup = [&]() -> void {
// Generate up to n_draft additional tokens through prompt lookup.
// The draft is aborted early if there is no suitable token candidate to continue the draft.
// At the beginning of this function the draft already contains a single token sampled from the model.
const int inp_size = inp.size();
while ((int) draft.size()-1 < n_draft) {
bool draft_success = false;
for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
if (ngram_size > inp_size) {
continue;
}
all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];
const int ngram_start = inp_size-ngram_size + draft.size()-1;
uint64_t ngram = get_token(inp, draft, ngram_start);
for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
const uint64_t ngram_part = get_token(inp, draft, j);
ngram <<= 16;
ngram |= ngram_part;
}
all_token_hashmap::iterator token_counts_it = atc.find(ngram);
if (token_counts_it == atc.end()) {
continue;
}
const token_hashmap token_counts = token_counts_it->second;
int max_count = 0;
int sum_count = 0;
llama_token max_token = -1;
for (std::pair<llama_token, int> tc : token_counts) {
const llama_token token = tc.first;
const llama_token count = tc.second;
if (count > max_count) {
max_token = token;
max_count = count;
}
sum_count += count;
}
// Skip this candidate if the sample size is too low:
if (sum_count < draft_min_sample_size[ngram_size-1]) {
continue;
}
// skip this candidate if the empirically most likely token following this token is not likely enough:
if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
continue;
}
LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
draft.push_back(max_token);
draft_success = true;
break;
}
if (!draft_success) {
break;
}
}
};
// Draft already contains a single token sampled from the model:
GGML_ASSERT(draft.size() == 1);
GGML_ASSERT(draft[0] == inp.back());
const int64_t t_start_draft_us = ggml_time_us();
prompt_lookup();
t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt);
++n_past;