diff --git a/examples/lookup-static/lookup-static.cpp b/examples/lookup-static/lookup-static.cpp index 2424b4b5f..364ce4e18 100644 --- a/examples/lookup-static/lookup-static.cpp +++ b/examples/lookup-static/lookup-static.cpp @@ -62,17 +62,27 @@ int main(int argc, char ** argv){ inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); inp_static = ::llama_tokenize(ctx, static_input, add_bos, true); - std::unordered_map hashmap = {}; - for (size_t i = 0; i < inp_static.size()-1; ++i) { + std::unordered_map> hashmap = {}; + for (size_t i = 0; i < inp_static.size()-2; ++i) { const int64_t key_low = inp_static[i + 0] << 0; const int64_t key_high = inp_static[i + 1] << 32; + const int64_t value = inp_static[i + 2]; const int64_t key = key_low | key_high; - if (hashmap.count(key) != 0) { - continue; + auto frequency_it = hashmap.find(key); + std::unordered_map frequency; + if (frequency_it != hashmap.end()) { + frequency = frequency_it->second; + } else { + hashmap.emplace(std::make_pair(key, frequency)); } - hashmap.emplace(std::make_pair(key, -1)); + auto token_it = frequency.find(value); + if (token_it != frequency.end()) { + token_it->second++; + } else { + frequency.emplace(std::make_pair(value, 1)); + } } printf("\n\n%ld\n\n", hashmap.size());