count token combinations
This commit is contained in:
parent
6d47013d81
commit
1d6059a5e2
1 changed file with 20 additions and 2 deletions
|
@ -2,15 +2,17 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
int main(int argc, char ** argv){
|
int main(int argc, char ** argv){
|
||||||
const char * static_input_file = "./wikitext-2-raw/wiki.test.raw";
|
const char * static_input_file = "./wikitext-2-raw/wiki.train.raw";
|
||||||
std::ifstream file(static_input_file);
|
std::ifstream file(static_input_file);
|
||||||
if (!file) {
|
if (!file) {
|
||||||
fprintf(stderr, "error: failed to open file '%s'\n", static_input_file);
|
fprintf(stderr, "error: failed to open file '%s'\n", static_input_file);
|
||||||
|
@ -56,7 +58,23 @@ int main(int argc, char ** argv){
|
||||||
LOG("add_bos tgt: %d\n", add_bos);
|
LOG("add_bos tgt: %d\n", add_bos);
|
||||||
|
|
||||||
std::vector<llama_token> inp;
|
std::vector<llama_token> inp;
|
||||||
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
std::vector<llama_token> inp_static;
|
||||||
|
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
||||||
|
inp_static = ::llama_tokenize(ctx, static_input, add_bos, true);
|
||||||
|
|
||||||
|
std::unordered_map<int64_t, llama_token> hashmap = {};
|
||||||
|
for (size_t i = 0; i < inp_static.size()-1; ++i) {
|
||||||
|
const int64_t key_low = inp_static[i + 0] << 0;
|
||||||
|
const int64_t key_high = inp_static[i + 1] << 32;
|
||||||
|
const int64_t key = key_low | key_high;
|
||||||
|
|
||||||
|
if (hashmap.count(key) != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
hashmap.emplace(std::make_pair(key, -1));
|
||||||
|
}
|
||||||
|
printf("\n\n%ld\n\n", hashmap.size());
|
||||||
|
|
||||||
const int max_context_size = llama_n_ctx(ctx);
|
const int max_context_size = llama_n_ctx(ctx);
|
||||||
const int max_tokens_list_size = max_context_size - 4;
|
const int max_tokens_list_size = max_context_size - 4;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue