diff --git a/llama.cpp b/llama.cpp
index 0a44cf668..31e7cbd9a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2328,7 +2328,7 @@ struct llama_context {
     // caching token pieces & their decoded codepoints.
     std::vector<std::string>                   token_pieces;
     std::vector<std::pair<std::vector<uint32_t>,
-                          llama_partial_utf8>> token_codepoints;
+                          llama_partial_utf8>> token_codepoints_without_partial_utf8_prefix;
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -13557,17 +13557,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
-    if (ctx->token_codepoints.empty()) {
-        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-        ctx->token_codepoints.resize(n_vocab);
-        ctx->token_pieces.resize(n_vocab);
-        for (llama_token id = 0; id < n_vocab; ++id) {
-            const std::string piece = llama_token_to_piece(ctx, id, false);
-            ctx->token_pieces[id] = piece;
-            ctx->token_codepoints[id] = decode_utf8(piece, {0, 0});
-        }
-    }
-
+    // Store decoded codepoints when they are not cached.
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     if (grammar->partial_utf8.n_remain > 0) {
         candidates_decoded.reserve(candidates->size);
@@ -13585,7 +13575,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else if (grammar->partial_utf8.n_remain == 0){
-            const auto & decoded = ctx->token_codepoints.at(id);
+            const auto & decoded = ctx->token_codepoints_without_partial_utf8_prefix.at(id);
             candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
@@ -13787,7 +13777,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
 
     // Note terminating 0 in decoded string
     const auto   decoded     = grammar->partial_utf8.n_remain == 0
-        ? ctx->token_codepoints[token]
+        ? ctx->token_codepoints_without_partial_utf8_prefix[token]
         : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
@@ -15645,6 +15635,18 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    // cache tokens & their decoded codepoints (for common case where there's no partial utf8 prefix bytes) for grammar constrained sampling.
+    {
+        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+        ctx->token_codepoints_without_partial_utf8_prefix.resize(n_vocab);
+        ctx->token_pieces.resize(n_vocab);
+        for (llama_token id = 0; id < n_vocab; ++id) {
+            const std::string piece = llama_token_to_piece(ctx, id, false);
+            ctx->token_pieces[id] = piece;
+            ctx->token_codepoints_without_partial_utf8_prefix[id] = decode_utf8(piece, {0, 0});
+        }
+    }
+
 #ifdef GGML_USE_MPI
     ctx->ctx_mpi = ggml_mpi_init();
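
For reference, here is a minimal standalone sketch of the caching pattern the diff introduces: decode every token piece once up front, and reuse the cached result at sampling time whenever the previous token did not leave a partial UTF-8 sequence behind. The `partial_utf8`, `decode_utf8`, and `token_cache` names and their simplified bodies below are illustrative stand-ins, not the actual llama.cpp types or API; in particular, the decoder skips llama.cpp's validation and terminating-0 handling.

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Stand-in for llama_partial_utf8: state of an unfinished multi-byte sequence.
struct partial_utf8 {
    uint32_t value;    // bits accumulated so far
    int      n_remain; // continuation bytes still expected (0 = none pending)
};

// Simplified decoder: turns a byte string into codepoints, resuming from a partial
// sequence left over by a previous call, and returns the new partial state.
static std::pair<std::vector<uint32_t>, partial_utf8>
decode_utf8(const std::string & src, partial_utf8 state) {
    std::vector<uint32_t> out;
    uint32_t value    = state.value;
    int      n_remain = state.n_remain;
    for (unsigned char byte : src) {
        if (n_remain > 0) {
            // continuation byte (0b10xxxxxx) of a sequence started earlier
            value = (value << 6) | (byte & 0x3F);
            if (--n_remain == 0) out.push_back(value);
        } else if (byte < 0x80) {
            out.push_back(byte);               // 1-byte (ASCII)
        } else if ((byte >> 5) == 0x6) {
            value = byte & 0x1F; n_remain = 1; // 2-byte lead
        } else if ((byte >> 4) == 0xE) {
            value = byte & 0x0F; n_remain = 2; // 3-byte lead
        } else {
            value = byte & 0x07; n_remain = 3; // 4-byte lead (no validation here)
        }
    }
    return { std::move(out), partial_utf8{ value, n_remain } };
}

// Per-context cache mirroring the fields added to llama_context above: each token's
// piece and its decoding, valid only when no partial prefix is pending.
struct token_cache {
    std::vector<std::string>                                     pieces;
    std::vector<std::pair<std::vector<uint32_t>, partial_utf8>>  codepoints_no_prefix;

    // Built once up front, like the loop added to llama_new_context_with_model().
    explicit token_cache(std::vector<std::string> vocab_pieces) : pieces(std::move(vocab_pieces)) {
        codepoints_no_prefix.resize(pieces.size());
        for (size_t id = 0; id < pieces.size(); ++id) {
            codepoints_no_prefix[id] = decode_utf8(pieces[id], {0, 0});
        }
    }

    // Sampling-time lookup: reuse the cache when the previous token ended on a
    // codepoint boundary (the common case); otherwise decode with the carried prefix.
    std::pair<std::vector<uint32_t>, partial_utf8> decode(size_t id, partial_utf8 state) const {
        if (state.n_remain == 0) {
            return codepoints_no_prefix[id];       // cache hit, no re-decoding
        }
        return decode_utf8(pieces[id], state);     // rare path: finish the pending sequence
    }
};

int main() {
    // Toy vocabulary: "é" (0xC3 0xA9) split across two tokens, plus a plain ASCII token.
    token_cache cache({ "\xC3", "\xA9", "hi" });

    partial_utf8 state = { 0, 0 };
    auto a = cache.decode(0, state);    // cache hit: lead byte only, no complete codepoint yet
    auto b = cache.decode(1, a.second); // prefix pending -> decoded on the fly
    std::printf("codepoint: U+%04X\n", (unsigned) (b.first.empty() ? 0 : b.first[0])); // U+00E9
}
```

The cache is only consulted when `n_remain == 0` because a piece's decoding depends on any UTF-8 bytes left unfinished by the previous token; since most tokens end on a codepoint boundary, the precomputed entries cover the common case and the per-candidate `decode_utf8` call drops out of the sampling hot loop.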