works (?), 2.28% accept

This commit is contained in:
JohannesGaessler 2024-02-07 20:38:28 +01:00
parent 1574279273
commit 6cecefd64c

View file

@ -1,4 +1,5 @@
#include "common.h" #include "common.h"
#include "ggml.h"
#include "llama.h" #include "llama.h"
#include <cmath> #include <cmath>
@ -64,17 +65,18 @@ int main(int argc, char ** argv){
std::unordered_map<int64_t, std::unordered_map<llama_token, int>> hashmap = {}; std::unordered_map<int64_t, std::unordered_map<llama_token, int>> hashmap = {};
for (size_t i = 0; i < inp_static.size()-2; ++i) { for (size_t i = 0; i < inp_static.size()-2; ++i) {
const int64_t key_low = inp_static[i + 0] << 0; int64_t key_low = inp_static[i + 0];
const int64_t key_high = inp_static[i + 1] << 32; int64_t key_high = inp_static[i + 1];
const int64_t value = inp_static[i + 2]; key_low <<= 0;
key_high <<= 32;
const int64_t key = key_low | key_high; const int64_t key = key_low | key_high;
const llama_token value = inp_static[i + 2];
auto frequency_it = hashmap.find(key); auto frequency_it = hashmap.find(key);
std::unordered_map<llama_token, int> frequency; std::unordered_map<llama_token, int> frequency;
if (frequency_it != hashmap.end()) { if (frequency_it != hashmap.end()) {
frequency = frequency_it->second; frequency = frequency_it->second;
} else {
hashmap.emplace(std::make_pair(key, frequency));
} }
auto token_it = frequency.find(value); auto token_it = frequency.find(value);
@ -83,12 +85,17 @@ int main(int argc, char ** argv){
} else { } else {
frequency.emplace(std::make_pair(value, 1)); frequency.emplace(std::make_pair(value, 1));
} }
if (frequency_it == hashmap.end()) {
hashmap.emplace(std::make_pair(key, frequency));
}
} }
printf("\n\n%ld\n\n", hashmap.size()); printf("\n\n%ld\n\n", hashmap.size());
std::unordered_map<int64_t, llama_token> hashmap_max; std::unordered_map<int64_t, llama_token> hashmap_max;
for (auto item : hashmap) { for (auto item : hashmap) {
const int64_t key = item.first; const int64_t key = item.first;
const std::unordered_map<llama_token, int> frequency = item.second; const std::unordered_map<llama_token, int> frequency = item.second;
GGML_ASSERT(!frequency.empty());
llama_token max_token = -1; llama_token max_token = -1;
int max_frequency = 0; int max_frequency = 0;
@ -98,6 +105,7 @@ int main(int argc, char ** argv){
max_frequency = item2.second; max_frequency = item2.second;
} }
} }
GGML_ASSERT(max_token != -1);
hashmap_max.emplace(std::make_pair(key, max_token)); hashmap_max.emplace(std::make_pair(key, max_token));
} }
@ -183,6 +191,7 @@ int main(int argc, char ** argv){
++n_past; ++n_past;
++i_dft; ++i_dft;
inp.push_back(id); inp.push_back(id);
// fprintf(stderr, "pushed: %d\n", id);
if (params.use_color) { if (params.use_color) {
// color accepted draft token // color accepted draft token
@ -203,6 +212,7 @@ int main(int argc, char ** argv){
draft.clear(); draft.clear();
draft.push_back(id); draft.push_back(id);
inp.push_back(id); inp.push_back(id);
// fprintf(stderr, "pushed: %d\n", id);
break; break;
} }
@ -220,7 +230,19 @@ int main(int argc, char ** argv){
// generate n_pred tokens through prompt lookup // generate n_pred tokens through prompt lookup
auto prompt_lookup = [&]() -> void { auto prompt_lookup = [&]() -> void {
for (int i = 0; i < n_draft; ++i) { for (int i = 0; i < n_draft; ++i) {
draft.push_back(rand() % 32000); // fprintf(stderr, "lookup: %d %d\n", inp[inp.size() - 2], inp[inp.size() - 1]);
int64_t key_low = inp[inp.size() - 2];
int64_t key_high = inp[inp.size() - 1];
key_low <<= 0;
key_high <<= 32;
const int64_t key = key_low | key_high;
auto item_it = hashmap_max.find(key);
if (item_it == hashmap_max.end()) {
break;
}
draft.push_back(item_it->second);
++n_drafted; ++n_drafted;
} }
return; return;