fix usage of llama_tokenize

xaedes 2023-09-16 20:36:43 +02:00
parent d3e06d3e73
commit 7930caf24c
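llama_tokenize now takes the length of the input text as an explicit argument. Both call sites in tokenize_file are updated to pass it ((int) buf.size() and (int) buf_sample.size()); the existing resize-and-retry handling of the negative return value is unchanged.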

@@ -853,10 +853,22 @@ size_t tokenize_file(
         // tokenize all data at once
         out_tokens.resize(buf.size() + n_max_tokens_overhead);
-        int n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), (int) out_tokens.size(), false);
+        int n_tokens = llama_tokenize(
+            lctx,
+            buf.data(),
+            (int) buf.size(),
+            out_tokens.data(),
+            (int) out_tokens.size(),
+            false);
         if (n_tokens < 0) {
             out_tokens.resize(-n_tokens);
-            n_tokens = llama_tokenize(lctx, buf.data(), out_tokens.data(), (int) out_tokens.size(), false);
+            n_tokens = llama_tokenize(
+                lctx,
+                buf.data(),
+                (int) buf.size(),
+                out_tokens.data(),
+                (int) out_tokens.size(),
+                false);
         }
         if (n_tokens >= 0) {
             out_tokens.resize(n_tokens);
@@ -948,14 +960,18 @@ size_t tokenize_file(
             tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
             int n_tokens = llama_tokenize(lctx,
                 buf_sample.data(),
+                (int) buf_sample.size(),
                 tok_sample.data(),
-                (int) tok_sample.size(), false);
+                (int) tok_sample.size(),
+                false);
             if (n_tokens < 0) {
                 tok_sample.resize(-n_tokens);
                 n_tokens = llama_tokenize(lctx,
                     buf_sample.data(),
+                    (int) buf_sample.size(),
                     tok_sample.data(),
-                    (int) tok_sample.size(), false);
+                    (int) tok_sample.size(),
+                    false);
                 GGML_ASSERT(n_tokens >= 0);
             }
             GGML_ASSERT(n_tokens <= (int) tok_sample.size());
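
For reference, a minimal sketch of the call pattern both hunks converge on. The wrapper name tokenize_all and the initial capacity guess are illustrative, not from the commit; the llama_tokenize argument order (context, text, text length, output buffer, buffer capacity, add_bos) and the convention that a negative return value is minus the required token count are taken from the diff itself.

#include <string>
#include <vector>

#include "ggml.h"   // GGML_ASSERT
#include "llama.h"

// Illustrative helper: tokenize a whole buffer with the updated
// llama_tokenize signature, retrying once if the capacity guess is too small.
static std::vector<llama_token> tokenize_all(struct llama_context * lctx, const std::string & buf, bool add_bos) {
    std::vector<llama_token> tokens(buf.size() + 2); // rough capacity guess
    int n_tokens = llama_tokenize(
        lctx,
        buf.data(),
        (int) buf.size(),    // new: input text length, passed explicitly
        tokens.data(),
        (int) tokens.size(), // output buffer capacity
        add_bos);
    if (n_tokens < 0) {
        // buffer too small: -n_tokens is the required capacity
        tokens.resize(-n_tokens);
        n_tokens = llama_tokenize(
            lctx,
            buf.data(),
            (int) buf.size(),
            tokens.data(),
            (int) tokens.size(),
            add_bos);
    }
    GGML_ASSERT(n_tokens >= 0);
    tokens.resize(n_tokens);
    return tokens;
}

Passing the length explicitly also means the input does not have to be NUL-terminated, which matters when tokenizing a slice of a larger training buffer.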