pass correct max number of tokens to llama_tokenize

This commit is contained in:
xaedes 2023-09-14 03:04:04 +02:00
parent 7f378a7561
commit f627e2fe9c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@@ -1266,10 +1266,10 @@ size_t tokenize_file(
     // tokenize all data at once
     out_tokens.resize(buf.size());
-    int n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), buf.size(), false);
+    int n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), out_tokens.size(), false);
     if (n_tokens < 0) {
         out_tokens.resize(-n_tokens);
-        n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), buf.size(), false);
+        n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), out_tokens.size(), false);
     }
     if (n_tokens >= 0) {
         out_tokens.resize(n_tokens);
@@ -1362,13 +1362,13 @@ size_t tokenize_file(
         int n_tokens = llama_tokenize(lctx,
             buf_sample.data(),
             tok_sample.data(),
-            sample_size, false);
+            tok_sample.size(), false);
         if (n_tokens < 0) {
             tok_sample.resize(-n_tokens);
             n_tokens = llama_tokenize(lctx,
                 buf_sample.data(),
                 tok_sample.data(),
-                sample_size, false);
+                tok_sample.size(), false);
             GGML_ASSERT(n_tokens >= 0);
         }
         GGML_ASSERT(n_tokens <= tok_sample.size());