From 6cecefd64c60f0d90730642782a6ecd1e73f2d3c Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 7 Feb 2024 20:38:28 +0100 Subject: [PATCH] works (?), 2.28% accept --- examples/lookup-static/lookup-static.cpp | 34 +++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/examples/lookup-static/lookup-static.cpp b/examples/lookup-static/lookup-static.cpp index d468be067..24939a132 100644 --- a/examples/lookup-static/lookup-static.cpp +++ b/examples/lookup-static/lookup-static.cpp @@ -1,4 +1,5 @@ #include "common.h" +#include "ggml.h" #include "llama.h" #include @@ -64,17 +65,18 @@ int main(int argc, char ** argv){ std::unordered_map> hashmap = {}; for (size_t i = 0; i < inp_static.size()-2; ++i) { - const int64_t key_low = inp_static[i + 0] << 0; - const int64_t key_high = inp_static[i + 1] << 32; - const int64_t value = inp_static[i + 2]; + int64_t key_low = inp_static[i + 0]; + int64_t key_high = inp_static[i + 1]; + key_low <<= 0; + key_high <<= 32; const int64_t key = key_low | key_high; + const llama_token value = inp_static[i + 2]; + auto frequency_it = hashmap.find(key); std::unordered_map frequency; if (frequency_it != hashmap.end()) { frequency = frequency_it->second; - } else { - hashmap.emplace(std::make_pair(key, frequency)); } auto token_it = frequency.find(value); @@ -83,12 +85,17 @@ int main(int argc, char ** argv){ } else { frequency.emplace(std::make_pair(value, 1)); } + + if (frequency_it == hashmap.end()) { + hashmap.emplace(std::make_pair(key, frequency)); + } } printf("\n\n%ld\n\n", hashmap.size()); std::unordered_map hashmap_max; for (auto item : hashmap) { const int64_t key = item.first; const std::unordered_map frequency = item.second; + GGML_ASSERT(!frequency.empty()); llama_token max_token = -1; int max_frequency = 0; @@ -98,6 +105,7 @@ int main(int argc, char ** argv){ max_frequency = item2.second; } } + GGML_ASSERT(max_token != -1); hashmap_max.emplace(std::make_pair(key, max_token)); } @@ -183,6 +191,7 @@ int main(int argc, char ** argv){ ++n_past; ++i_dft; inp.push_back(id); + // fprintf(stderr, "pushed: %d\n", id); if (params.use_color) { // color accepted draft token @@ -203,6 +212,7 @@ int main(int argc, char ** argv){ draft.clear(); draft.push_back(id); inp.push_back(id); + // fprintf(stderr, "pushed: %d\n", id); break; } @@ -220,7 +230,19 @@ int main(int argc, char ** argv){ // generate n_pred tokens through prompt lookup auto prompt_lookup = [&]() -> void { for (int i = 0; i < n_draft; ++i) { - draft.push_back(rand() % 32000); + // fprintf(stderr, "lookup: %d %d\n", inp[inp.size() - 2], inp[inp.size() - 1]); + int64_t key_low = inp[inp.size() - 2]; + int64_t key_high = inp[inp.size() - 1]; + key_low <<= 0; + key_high <<= 32; + const int64_t key = key_low | key_high; + + auto item_it = hashmap_max.find(key); + if (item_it == hashmap_max.end()) { + break; + } + + draft.push_back(item_it->second); ++n_drafted; } return;