diff --git a/llama.cpp b/llama.cpp
index 33b948470..704c5e24b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13051,7 +13051,7 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);
 
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, {}, {} };
 }
 
 void llama_grammar_free(struct llama_grammar * grammar) {
@@ -13059,7 +13059,7 @@ void llama_grammar_free(struct llama_grammar * grammar) {
 }
 
 struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
@@ -13540,7 +13540,7 @@ void llama_sample_repetition_penalties(
     }
 }
 
-void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
+void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -13552,21 +13552,36 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
+    if (grammar->token_codepoints.empty()) {
+        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+        grammar->token_codepoints.resize(n_vocab);
+        grammar->token_pieces.resize(n_vocab);
+        for (llama_token id = 0; id < n_vocab; ++id) {
+            const std::string piece = llama_token_to_piece(ctx, id, false);
+            grammar->token_pieces[id] = piece;
+            grammar->token_codepoints[id] = decode_utf8(piece, {0, 0});
+        }
+    }
+
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    if (grammar->partial_utf8.n_remain > 0) {
+        candidates_decoded.reserve(candidates->size);
+    }
     std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
-
+        const auto & piece = grammar->token_pieces[id];
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
+        } else if (grammar->partial_utf8.n_remain == 0) {
+            const auto & decoded = grammar->token_codepoints.at(id);
+            candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
@@ -13763,10 +13778,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const auto & piece = grammar->token_pieces.at(token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto decoded = grammar->partial_utf8.n_remain == 0
+        ? grammar->token_codepoints[token]
+        : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
diff --git a/llama.h b/llama.h
index 8b1b15ed4..13c596353 100644
--- a/llama.h
+++ b/llama.h
@@ -961,7 +961,7 @@ extern "C" {
     LLAMA_API void llama_sample_grammar(
             struct llama_context * ctx,
           llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
+            struct llama_grammar * grammar);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1099,6 +1099,11 @@ struct llama_grammar {
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
+
+    // caching the token pieces & their decoded codepoints.
+    std::vector<std::string> token_pieces;
+    std::vector<std::pair<std::vector<uint32_t>,
+                          llama_partial_utf8>> token_codepoints;
 };
 
 struct llama_grammar_candidate {
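
For reference, a self-contained sketch of the caching pattern the patch introduces. The names here (partial_utf8, decoded, codepoint_cache, ensure, get) are illustrative stand-ins, not the llama.cpp API; decode_utf8 is modeled on the internal helper of the same name. The idea: decode every vocab piece from a clean state exactly once, and reuse that cached decoding on the common path where no UTF-8 sequence is left pending across token boundaries (n_remain == 0).

// Sketch of the lazy codepoint cache (hypothetical names).
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct partial_utf8 {
    uint32_t value;    // bits accumulated so far for an unfinished sequence
    int      n_remain; // continuation bytes still expected (-1 = invalid)
};

using decoded = std::pair<std::vector<uint32_t>, partial_utf8>;

// Decode `src` starting from a possibly mid-sequence `start` state; returns
// the codepoints (with a terminating 0, as in the patch) plus the state of
// any incomplete trailing sequence.
static decoded decode_utf8(const std::string & src, partial_utf8 start) {
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    std::vector<uint32_t> cps;
    uint32_t value    = start.value;
    int      n_remain = start.n_remain;
    size_t   i        = 0;

    // finish a sequence carried over from the previous piece
    for (; i < src.size() && n_remain > 0; ++i, --n_remain) {
        uint8_t b = static_cast<uint8_t>(src[i]);
        if ((b >> 6) != 2) { return { {0}, {0, -1} }; } // bad continuation byte
        value = (value << 6) + (b & 0x3F);
    }
    if (start.n_remain > 0 && n_remain == 0) { cps.push_back(value); }

    // decode whole sequences; the last one may be left incomplete
    while (i < src.size()) {
        uint8_t b = static_cast<uint8_t>(src[i++]);
        n_remain  = lookup[b >> 4] - 1;                 // bytes after the lead byte
        if (n_remain < 0) { return { {0}, {0, -1} }; }  // bad lead byte
        value = b & ((1 << (7 - n_remain)) - 1);        // payload bits of the lead byte
        for (; i < src.size() && n_remain > 0; ++i, --n_remain) {
            value = (value << 6) + (static_cast<uint8_t>(src[i]) & 0x3F);
        }
        if (n_remain == 0) { cps.push_back(value); }
    }
    cps.push_back(0); // note the terminating 0
    return { std::move(cps), { value, n_remain } };
}

struct codepoint_cache {
    std::vector<std::string> pieces;     // token id -> raw piece
    std::vector<decoded>     codepoints; // token id -> piece decoded from a clean state

    // lazily build the cache on first use, as llama_sample_grammar now does
    void ensure(const std::vector<std::string> & vocab) {
        if (!codepoints.empty()) { return; }
        pieces = vocab;
        codepoints.resize(vocab.size());
        for (size_t id = 0; id < vocab.size(); ++id) {
            codepoints[id] = decode_utf8(pieces[id], {0, 0});
        }
    }

    // fast path: nothing pending, the cached decoding is valid as-is;
    // slow path: re-decode this piece against the carried-over state
    decoded get(size_t id, partial_utf8 state) const {
        return state.n_remain == 0 ? codepoints[id]
                                   : decode_utf8(pieces[id], state);
    }
};

This mirrors the two call sites in the patch: llama_sample_grammar hits the cached fast path for the whole candidate list whenever partial_utf8.n_remain == 0, and llama_grammar_accept_token only falls back to a fresh decode_utf8 when a multi-byte sequence straddles the accepted token boundary.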