From e3d8c5ecc92c377cb6de3a1a49daffbe742272b5 Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Sun, 11 Feb 2024 12:33:23 +0100
Subject: [PATCH] lookup: hashmap, most frequent tokens, abort early

---
 examples/lookup/lookup.cpp | 186 ++++++++++++++++++++++++++++++-------
 1 file changed, 150 insertions(+), 36 deletions(-)

diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 18235b8a1..7c569de45 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -7,6 +7,22 @@
 #include
 #include
 #include
+#include <unordered_map>
+
+// Data structures to map n-grams to empirical token probabilities:
+typedef std::unordered_map<llama_token, int> token_hashmap; // token -> number of times token has been seen
+typedef std::unordered_map<uint64_t, token_hashmap> all_token_hashmap; // n-gram -> empirical distribution of following tokens
+// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+// This way no custom hashing function for the n-grams is needed.
+
+// Min/max n-gram size to search for in prompt:
+constexpr int ngram_min = 1;
+constexpr int ngram_max = 4;
+static_assert(ngram_max <= sizeof(uint64_t)/2, "A 64 bit integer can only hold information for 4 16 bit tokens.");
+
+// If sample size or percentage in context are below these thresholds the draft is aborted early:
+constexpr float draft_min_sample_size[ngram_max] = { 2, 2, 1, 1};
+constexpr float draft_min_percent[ngram_max] = {66, 50, 50, 50};
 
 int main(int argc, char ** argv){
     gpt_params params;
@@ -16,9 +32,6 @@ int main(int argc, char ** argv){
     }
 
     // max/min n-grams size to search for in prompt
-    const int ngram_max = 4;
-    const int ngram_min = 1;
-
     // length of the candidate / draft sequence, if match is found
     const int n_draft = params.n_draft;
 
@@ -38,6 +51,7 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
     const bool add_bos = llama_should_add_bos_token(model);
@@ -46,6 +60,55 @@ int main(int argc, char ** argv){
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
 
+    auto update_hashmaps = [](all_token_hashmap * atcs, const llama_token * inp_data, const int inp_size, const int nnew) -> void {
+        // atcs = all_token_counts: the hashmaps to modify.
+        // inp_data: the token sequence on which the hashmaps are based.
+        // inp_size: the current size of inp_data.
+        // nnew: how many new tokens have been appended to inp_data since the last call to this function.
+        //
+        // In order to get correct results inp_data can ONLY BE APPENDED TO.
+        // Changes in the middle need a complete rebuild.
+        for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+            all_token_hashmap * atc = atcs + ngram_size - ngram_min;
+
+            const int i_start = std::max(inp_size - nnew, ngram_size);
+            for (int i = i_start; i < inp_size; ++i) {
+                const int ngram_start = i - ngram_size;
+                uint64_t ngram = inp_data[ngram_start];
+                for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                    const uint64_t ngram_part = inp_data[j];
+                    ngram <<= 16;
+                    ngram |= ngram_part;
+                }
+                const llama_token token = inp_data[i];
+
+                all_token_hashmap::iterator token_counts_it = atc->find(ngram);
+                if (token_counts_it == atc->end()) {
+                    token_hashmap token_counts;
+                    token_counts.emplace(token, 1);
+                    atc->emplace(ngram, token_counts);
+                } else {
+                    token_hashmap::iterator tc_it = token_counts_it->second.find(token);
+                    if (tc_it == token_counts_it->second.end()) {
+                        token_counts_it->second.emplace(token, 1);
+                    } else {
+                        tc_it->second++;
+                    }
+                }
+            }
+        }
+    };
+
+    all_token_hashmap all_token_counts[ngram_max-ngram_min+1];
+    int64_t t_draft_us = 0;
+
+    {
+        // Fill up hashmaps with tokens from user input:
+        const int64_t t_start_draft_us = ggml_time_us();
+        update_hashmaps(all_token_counts, inp.data(), inp.size(), inp.size());
+        t_draft_us += ggml_time_us() - t_start_draft_us;
+    }
+
     const int max_context_size = llama_n_ctx(ctx);
     const int max_tokens_list_size = max_context_size - 4;
 
@@ -75,8 +138,6 @@ int main(int argc, char ** argv){
     int n_drafted = 0;
     int n_accept = 0;
 
-    int64_t t_draft_us = 0;
-
     int n_past = inp.size();
 
     bool has_eos = false;
@@ -128,6 +189,12 @@ int main(int argc, char ** argv){
                 ++n_past;
                 ++i_dft;
                 inp.push_back(id);
+                {
+                    // Update hashmaps with the newly accepted token:
+                    const int64_t t_start_draft_us = ggml_time_us();
+                    update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                    t_draft_us += ggml_time_us() - t_start_draft_us;
+                }
 
                 if (params.use_color) {
                     // color accepted draft token
@@ -148,6 +215,12 @@ int main(int argc, char ** argv){
             draft.clear();
             draft.push_back(id);
             inp.push_back(id);
+            {
+                // Update hashmaps with the newly accepted token:
+                const int64_t t_start_draft_us = ggml_time_us();
+                update_hashmaps(all_token_counts, inp.data(), inp.size(), 1);
+                t_draft_us += ggml_time_us() - t_start_draft_us;
+            }
             break;
         }
 
@@ -162,44 +235,85 @@ int main(int argc, char ** argv){
         llama_batch_clear(batch_tgt);
         llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
 
-        // generate n_pred tokens through prompt lookup
-        auto prompt_lookup = [&]() -> void {
-            const int inp_size = inp.size();
-            for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
-                const llama_token * ngram = &inp[inp_size - ngram_size];
-
-                for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
-                    bool match = true;
-                    for (int j = 0; j < ngram_size; ++j) {
-                        if (inp[i + j] != ngram[j]) {
-                            match = false;
-                            break;
-                        }
-                    }
-
-                    if (match) {
-                        const int startIdx = i + ngram_size;
-                        const int endIdx = startIdx + n_draft;
-                        if (endIdx < inp_size) {
-                            for (int j = startIdx; j < endIdx; ++j) {
-                                LOG(" - draft candidate %d: %d\n", j, inp[j]);
-                                draft.push_back(inp[j]);
-                                llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
-                                ++n_drafted;
-                            }
-                            return;
-                        }
-                    }
-                }
-            }
-            return;
+        auto get_token = [](const std::vector<llama_token> inp, const std::vector<llama_token> draft, const size_t i) -> llama_token {
+            // Helper function to get a token from the combined, speculative sequence of inp and draft.
+            return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
         };
 
+        auto prompt_lookup = [&]() -> void {
+            // Generate up to n_draft additional tokens through prompt lookup.
+            // The draft is aborted early if there is no suitable token candidate to continue the draft.
+            // At the beginning of this function the draft already contains a single token sampled from the model.
+            const int inp_size = inp.size();
+
+            while ((int) draft.size()-1 < n_draft) {
+                bool draft_success = false;
+                for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+                    if (ngram_size > inp_size) {
+                        continue;
+                    }
+
+                    all_token_hashmap & atc = all_token_counts[ngram_size - ngram_min];
+
+                    const int ngram_start = inp_size-ngram_size + draft.size()-1;
+                    uint64_t ngram = get_token(inp, draft, ngram_start);
+                    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
+                        const uint64_t ngram_part = get_token(inp, draft, j);
+                        ngram <<= 16;
+                        ngram |= ngram_part;
+                    }
+
+                    all_token_hashmap::iterator token_counts_it = atc.find(ngram);
+                    if (token_counts_it == atc.end()) {
+                        continue;
+                    }
+                    const token_hashmap token_counts = token_counts_it->second;
+
+                    int max_count = 0;
+                    int sum_count = 0;
+                    llama_token max_token = -1;
+
+                    for (std::pair<llama_token, int> tc : token_counts) {
+                        const llama_token token = tc.first;
+                        const llama_token count = tc.second;
+
+                        if (count > max_count) {
+                            max_token = token;
+                            max_count = count;
+                        }
+                        sum_count += count;
+                    }
+                    // Skip this candidate if the sample size is too low:
+                    if (sum_count < draft_min_sample_size[ngram_size-1]) {
+                        continue;
+                    }
+                    // skip this candidate if the empirically most likely token following this token is not likely enough:
+                    if (100*max_count < draft_min_percent[ngram_size-1]*sum_count) {
+                        continue;
+                    }
+
+                    LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count);
+                    llama_batch_add(batch_tgt, max_token, n_past + draft.size(), { 0 }, true);
+                    draft.push_back(max_token);
+                    draft_success = true;
+                    break;
+                }
+
+                if (!draft_success) {
+                    break;
+                }
+            }
+        };
+
+        // Draft already contains a single token sampled from the model:
+        GGML_ASSERT(draft.size() == 1);
+        GGML_ASSERT(draft[0] == inp.back());
         const int64_t t_start_draft_us = ggml_time_us();
 
         prompt_lookup();
 
         t_draft_us += ggml_time_us() - t_start_draft_us;
+        n_drafted += draft.size() - 1;
 
         llama_decode(ctx, batch_tgt);
         ++n_past;
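
Note on the n-gram encoding used above: because GGML_ASSERT(llama_n_vocab(model) < (1 << 16)) guarantees that every token id fits into 16 bits, up to 4 token ids can be packed into a single uint64_t and used directly as the unordered_map key, which is why no custom hash function is needed. The stand-alone sketch below mirrors the shift/OR loop from update_hashmaps and prompt_lookup; the pack_ngram helper and the llama_token stand-in typedef are illustrative only and not part of the patch.

// ngram_pack_sketch.cpp - minimal illustration of the 64 bit n-gram key packing.
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t llama_token; // stand-in for the token type from llama.h

// Pack the ngram_size (<= 4) tokens that end just before index `end` into one 64 bit key.
// As in the patch, the key is seeded with the first token of the n-gram and then all
// ngram_size tokens (the first one again included) are shifted in from the right in
// 16 bit steps; token ids are assumed non-negative and below 2^16, as asserted above.
// The encoding is deterministic, so hashmap construction and lookup agree on the key.
static uint64_t pack_ngram(const std::vector<llama_token> & tokens, const int end, const int ngram_size) {
    const int ngram_start = end - ngram_size;
    uint64_t ngram = tokens[ngram_start];
    for (int j = ngram_start; j < ngram_start + ngram_size; ++j) {
        const uint64_t ngram_part = tokens[j];
        ngram <<= 16;
        ngram |= ngram_part;
    }
    return ngram;
}

int main() {
    const std::vector<llama_token> tokens = {11, 22, 33, 44, 55};
    // Key for the 4-gram {22, 33, 44, 55}, i.e. the last 4 tokens of the sequence:
    printf("%016llx\n", (unsigned long long) pack_ngram(tokens, (int) tokens.size(), 4));
    return 0;
}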