n_considered configurable
This commit is contained in:
parent
e390b22f57
commit
449585a498
1 changed files with 20 additions and 16 deletions
|
@ -63,13 +63,16 @@ int main(int argc, char ** argv){
|
||||||
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
||||||
inp_static = ::llama_tokenize(ctx, static_input, add_bos, true);
|
inp_static = ::llama_tokenize(ctx, static_input, add_bos, true);
|
||||||
|
|
||||||
std::unordered_map<int64_t, std::unordered_map<llama_token, int>> hashmap = {};
|
constexpr int n_considered = 2;
|
||||||
for (size_t i = 0; i < inp_static.size()-2; ++i) {
|
|
||||||
int64_t key_low = inp_static[i + 0];
|
std::unordered_map<uint64_t, std::unordered_map<llama_token, int>> hashmap = {};
|
||||||
int64_t key_high = inp_static[i + 1];
|
for (size_t i = 0; i < inp_static.size()-n_considered; ++i) {
|
||||||
key_low <<= 0;
|
uint64_t key = inp_static[i];
|
||||||
key_high <<= 32;
|
for (int j = 1; j < n_considered; ++j) {
|
||||||
const int64_t key = key_low | key_high;
|
uint64_t key_part = inp_static[i + j];
|
||||||
|
key <<= 16;
|
||||||
|
key |= key_part;
|
||||||
|
}
|
||||||
|
|
||||||
const llama_token value = inp_static[i + 2];
|
const llama_token value = inp_static[i + 2];
|
||||||
|
|
||||||
|
@ -90,10 +93,10 @@ int main(int argc, char ** argv){
|
||||||
hashmap.emplace(std::make_pair(key, frequency));
|
hashmap.emplace(std::make_pair(key, frequency));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n%ld\n\n", hashmap.size());
|
// printf("\n\n%ld\n\n", hashmap.size());
|
||||||
std::unordered_map<int64_t, llama_token> hashmap_max;
|
std::unordered_map<uint64_t, llama_token> hashmap_max;
|
||||||
for (auto item : hashmap) {
|
for (auto item : hashmap) {
|
||||||
const int64_t key = item.first;
|
const uint64_t key = item.first;
|
||||||
const std::unordered_map<llama_token, int> frequency = item.second;
|
const std::unordered_map<llama_token, int> frequency = item.second;
|
||||||
GGML_ASSERT(!frequency.empty());
|
GGML_ASSERT(!frequency.empty());
|
||||||
|
|
||||||
|
@ -109,7 +112,7 @@ int main(int argc, char ** argv){
|
||||||
|
|
||||||
hashmap_max.emplace(std::make_pair(key, max_token));
|
hashmap_max.emplace(std::make_pair(key, max_token));
|
||||||
}
|
}
|
||||||
printf("\n\n%ld\n\n", hashmap_max.size());
|
// printf("\n\n%ld\n\n", hashmap_max.size());
|
||||||
|
|
||||||
const int max_context_size = llama_n_ctx(ctx);
|
const int max_context_size = llama_n_ctx(ctx);
|
||||||
const int max_tokens_list_size = max_context_size - 4;
|
const int max_tokens_list_size = max_context_size - 4;
|
||||||
|
@ -231,11 +234,12 @@ int main(int argc, char ** argv){
|
||||||
auto prompt_lookup = [&]() -> void {
|
auto prompt_lookup = [&]() -> void {
|
||||||
for (int i = 0; i < n_draft; ++i) {
|
for (int i = 0; i < n_draft; ++i) {
|
||||||
// fprintf(stderr, "lookup: %d %d\n", inp[inp.size() - 2], inp[inp.size() - 1]);
|
// fprintf(stderr, "lookup: %d %d\n", inp[inp.size() - 2], inp[inp.size() - 1]);
|
||||||
int64_t key_low = inp[inp.size() - 2];
|
uint64_t key = inp[inp.size() - n_considered];
|
||||||
int64_t key_high = inp[inp.size() - 1];
|
for (int j = 1; j < n_considered; ++j) {
|
||||||
key_low <<= 0;
|
const uint64_t key_part = inp[inp.size() - n_considered + j];
|
||||||
key_high <<= 32;
|
key <<= 16;
|
||||||
const int64_t key = key_low | key_high;
|
key |= key_part;
|
||||||
|
}
|
||||||
|
|
||||||
auto item_it = hashmap_max.find(key);
|
auto item_it = hashmap_max.find(key);
|
||||||
if (item_it == hashmap_max.end()) {
|
if (item_it == hashmap_max.end()) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue