account for possible leading whitespace that will be added by tokenizer

e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
xaedes 2023-09-14 10:48:38 +02:00
parent f627e2fe9c
commit 2c59f7bea3

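For context on the commit message, a minimal sketch (not part of the commit) of how the leading-whitespace behaviour can be observed. It assumes an already initialized llama_context and the 5-argument llama_tokenize signature used by tokenize_file in the diff below; the helper name print_tokenization is illustrative only.

#include <cstdio>
#include <cstring>
#include <vector>
#include "llama.h"

// print the token ids the SPM tokenizer produces for `text`
static void print_tokenization(struct llama_context * lctx, const char * text) {
    // n_chars + 1 slots: at most one extra token for the injected leading whitespace
    std::vector<llama_token> toks(strlen(text) + 1);
    int n_tokens = llama_tokenize(lctx, text, toks.data(), (int) toks.size(), false);
    printf("'%s' -> %d tokens:", text, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        printf(" %d", toks[i]); // e.g. '\t' -> 29871 12 (leading whitespace + tab)
    }
    printf("\n");
}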

@@ -1250,6 +1250,10 @@ size_t tokenize_file(
         return out_tokens.size();
     }
+    // account for possible leading whitespace that will be added by tokenizer
+    // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
+    const int n_max_tokens_overhead = 1;
+
     std::vector<char> buf;
     buf.resize(f.size+1);
@@ -1264,7 +1268,7 @@ size_t tokenize_file(
     if (sample_start.size() == 0) {
         // tokenize all data at once
-        out_tokens.resize(buf.size());
+        out_tokens.resize(buf.size() + n_max_tokens_overhead);
         int n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), out_tokens.size(), false);
         if (n_tokens < 0) {
@@ -1358,7 +1362,7 @@ size_t tokenize_file(
             // printf("sample: '%s'\n", buf_sample.data());
             // tokenize the sample
-            tok_sample.resize(sample_size);
+            tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
             int n_tokens = llama_tokenize(lctx,
                     buf_sample.data(),
                     tok_sample.data(),
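
Taken together, the change amounts to one allocation pattern: over-allocate the token buffer by n_max_tokens_overhead, tokenize, then keep only the tokens actually produced. A minimal sketch of that pattern, assuming the same 5-argument llama_tokenize API as above; the helper name tokenize_with_overhead is illustrative, and the retry on a negative return mirrors the error handling visible in the surrounding file.

#include <vector>
#include "llama.h"

// `buf` is expected to be null-terminated text, as in tokenize_file (buf.resize(f.size+1))
static std::vector<llama_token> tokenize_with_overhead(struct llama_context * lctx, const std::vector<char> & buf) {
    // account for the possible leading whitespace token added by the SPM tokenizer
    const int n_max_tokens_overhead = 1;
    std::vector<llama_token> tokens(buf.size() + n_max_tokens_overhead);
    int n_tokens = llama_tokenize(lctx, buf.data(), tokens.data(), (int) tokens.size(), false);
    if (n_tokens < 0) {
        // a negative return signals the buffer was too small; retry with the reported size
        tokens.resize(-n_tokens);
        n_tokens = llama_tokenize(lctx, buf.data(), tokens.data(), (int) tokens.size(), false);
    }
    tokens.resize(n_tokens); // shrink to the number of tokens actually produced
    return tokens;
}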