grammars: cache codepoints in llama_new_context_with_model

ochafik 2024-04-28 15:24:35 +01:00
parent d41f314740
commit 49f0faaa0e

llama.cpp

@@ -2328,7 +2328,7 @@ struct llama_context {
 
     // caching token pieces & their decoded codepoints.
     std::vector<std::string> token_pieces;
-    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> token_codepoints;
+    std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> token_codepoints_without_partial_utf8_prefix;
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -13557,17 +13557,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
-    if (ctx->token_codepoints.empty()) {
-        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-        ctx->token_codepoints.resize(n_vocab);
-        ctx->token_pieces.resize(n_vocab);
-        for (llama_token id = 0; id < n_vocab; ++id) {
-            const std::string piece = llama_token_to_piece(ctx, id, false);
-            ctx->token_pieces[id] = piece;
-            ctx->token_codepoints[id] = decode_utf8(piece, {0, 0});
-        }
-    }
-
+    // Store decoded codepoints when they are not cached.
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     if (grammar->partial_utf8.n_remain > 0) {
         candidates_decoded.reserve(candidates->size);
@@ -13585,7 +13575,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else if (grammar->partial_utf8.n_remain == 0){
-            const auto & decoded = ctx->token_codepoints.at(id);
+            const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
             candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
@@ -13787,7 +13777,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 
     // Note terminating 0 in decoded string
     const auto decoded = grammar->partial_utf8.n_remain == 0
-        ? ctx->token_codepoints[token]
+        ? ctx->token_codepoints_without_partial_utf8_prefix[token]
         : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
@@ -15645,6 +15635,18 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar constrained sampling.
+    {
+        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+        ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
+        ctx->token_pieces.resize(n_vocab);
+        for (llama_token id = 0; id < n_vocab; ++id) {
+            const std::string piece = llama_token_to_piece(ctx, id, false);
+            ctx->token_pieces[id] = piece;
+            ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
+        }
+    }
+
 #ifdef GGML_USE_MPI
     ctx->ctx_mpi = ggml_mpi_init();
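For context on the last hunk: the per-vocab decoding loop that previously ran lazily inside llama_sample_grammar now runs once in llama_new_context_with_model, so grammar-constrained sampling no longer needs a first-call initialization check. A minimal sketch of that construction-time pattern, using hypothetical toy types and the same byte-per-codepoint toy decode as above rather than the real llama.cpp types:

    // Sketch only: "build the cache when the context is created".
    #include <cstdint>
    #include <string>
    #include <vector>

    struct toy_vocab {
        std::vector<std::string> pieces; // one detokenized piece per token id
    };

    struct toy_context {
        std::vector<std::string>           token_pieces;
        std::vector<std::vector<uint32_t>> token_codepoints_without_partial_utf8_prefix;

        explicit toy_context(const toy_vocab & vocab) {
            const size_t n_vocab = vocab.pieces.size();
            token_pieces.resize(n_vocab);
            token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
            // Pay the O(n_vocab) decoding cost once, up front, instead of on the
            // first grammar-constrained sampling call.
            for (size_t id = 0; id < n_vocab; ++id) {
                token_pieces[id] = vocab.pieces[id];
                auto & cps = token_codepoints_without_partial_utf8_prefix[id];
                for (unsigned char b : vocab.pieces[id]) {
                    cps.push_back(b); // toy byte-per-codepoint "decode"
                }
                cps.push_back(0);     // terminating 0
            }
        }
    };

    int main() {
        const toy_vocab vocab = { { "a", "bc", "def" } };
        const toy_context ctx(vocab);
        // The sampler can now index the cache directly, with no lazy-init guard.
        return ctx.token_codepoints_without_partial_utf8_prefix.size() == 3 ? 0 : 1;
    }

The trade-off in the sketch matches the commit's direction: context creation gets slightly more work, and sampling keeps only the cheap n_remain == 0 check for deciding whether the cache applies.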