From 449585a4982b4e844ed2bba7a70da532ad324387 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 7 Feb 2024 21:26:49 +0100 Subject: [PATCH] n_considered configurable --- examples/lookup-static/lookup-static.cpp | 36 +++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/examples/lookup-static/lookup-static.cpp b/examples/lookup-static/lookup-static.cpp index 237f35e68..e3d56e49c 100644 --- a/examples/lookup-static/lookup-static.cpp +++ b/examples/lookup-static/lookup-static.cpp @@ -63,13 +63,16 @@ int main(int argc, char ** argv){ inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); inp_static = ::llama_tokenize(ctx, static_input, add_bos, true); - std::unordered_map> hashmap = {}; - for (size_t i = 0; i < inp_static.size()-2; ++i) { - int64_t key_low = inp_static[i + 0]; - int64_t key_high = inp_static[i + 1]; - key_low <<= 0; - key_high <<= 32; - const int64_t key = key_low | key_high; + constexpr int n_considered = 2; + + std::unordered_map> hashmap = {}; + for (size_t i = 0; i < inp_static.size()-n_considered; ++i) { + uint64_t key = inp_static[i]; + for (int j = 1; j < n_considered; ++j) { + uint64_t key_part = inp_static[i + j]; + key <<= 16; + key |= key_part; + } const llama_token value = inp_static[i + 2]; @@ -90,10 +93,10 @@ int main(int argc, char ** argv){ hashmap.emplace(std::make_pair(key, frequency)); } } - printf("\n\n%ld\n\n", hashmap.size()); - std::unordered_map hashmap_max; + // printf("\n\n%ld\n\n", hashmap.size()); + std::unordered_map hashmap_max; for (auto item : hashmap) { - const int64_t key = item.first; + const uint64_t key = item.first; const std::unordered_map frequency = item.second; GGML_ASSERT(!frequency.empty()); @@ -109,7 +112,7 @@ int main(int argc, char ** argv){ hashmap_max.emplace(std::make_pair(key, max_token)); } - printf("\n\n%ld\n\n", hashmap_max.size()); + // printf("\n\n%ld\n\n", hashmap_max.size()); const int max_context_size = llama_n_ctx(ctx); const int max_tokens_list_size = max_context_size - 4; @@ -231,11 +234,12 @@ int main(int argc, char ** argv){ auto prompt_lookup = [&]() -> void { for (int i = 0; i < n_draft; ++i) { // fprintf(stderr, "lookup: %d %d\n", inp[inp.size() - 2], inp[inp.size() - 1]); - int64_t key_low = inp[inp.size() - 2]; - int64_t key_high = inp[inp.size() - 1]; - key_low <<= 0; - key_high <<= 32; - const int64_t key = key_low | key_high; + uint64_t key = inp[inp.size() - n_considered]; + for (int j = 1; j < n_considered; ++j) { + const uint64_t key_part = inp[inp.size() - n_considered + j]; + key <<= 16; + key |= key_part; + } auto item_it = hashmap_max.find(key); if (item_it == hashmap_max.end()) {