grammars: cache decoded tokens

commit 00c709eb4a
parent 09c256594d
Author: ochafik
Date:   2024-04-21 15:52:16 +01:00

2 changed files with 31 additions and 9 deletions
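In short: `llama_sample_grammar` previously re-derived every candidate token's text piece (`llama_token_to_piece`) and its UTF-8 codepoints (`decode_utf8`) on each sampling step. This commit adds two lazily built per-grammar caches, `token_pieces` and `token_codepoints`, so that work happens once per grammar. A condensed before/after of the hot path, for orientation only (not runnable on its own; `decode_utf8` is an internal llama.cpp helper — the authoritative code is in the diff below):

```cpp
// Before: recomputed for every candidate, on every llama_sample_grammar call.
const std::string piece = llama_token_to_piece(ctx, id, false);
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));

// After: looked up from caches filled once per grammar. decode_utf8 is only
// called when a partial UTF-8 sequence is pending (n_remain > 0), because the
// cached codepoints were decoded from a fresh {0, 0} state.
const auto & piece   = grammar->token_pieces[id];
const auto & decoded = grammar->token_codepoints.at(id);
```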

llama.cpp

@@ -13051,7 +13051,7 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);
 
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, {}, {} };
 }
 
 void llama_grammar_free(struct llama_grammar * grammar) {
@@ -13059,7 +13059,7 @@ void llama_grammar_free(struct llama_grammar * grammar) {
 }
 
 struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
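Since `llama_grammar_copy` now copies `token_pieces` and `token_codepoints` along with the rules and stacks, a copied grammar starts with a warm cache. A hedged usage sketch — `n_seqs` and `base_grammar` are made-up names, while `llama_grammar_copy` and `llama_grammar_free` are the existing public API:

```cpp
// Hypothetical: one grammar state per parallel sequence. Each copy inherits
// the already-built token caches instead of re-decoding the whole vocabulary.
std::vector<llama_grammar *> seq_grammars;
for (int s = 0; s < n_seqs; ++s) {
    seq_grammars.push_back(llama_grammar_copy(base_grammar));
}
// ... sample each sequence against its own grammar state ...
for (auto * g : seq_grammars) {
    llama_grammar_free(g);
}
```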
@@ -13540,7 +13540,7 @@ void llama_sample_repetition_penalties(
         }
     }
 }
 
-void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
+void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
@ -13552,21 +13552,36 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
}
}
if (grammar->token_codepoints.empty()) {
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
grammar->token_codepoints.resize(n_vocab);
grammar->token_pieces.resize(n_vocab);
for (llama_token id = 0; id < n_vocab; ++id) {
const std::string piece = llama_token_to_piece(ctx, id, false);
grammar->token_pieces[id] = piece;
grammar->token_codepoints[id] = decode_utf8(piece, {0, 0});
}
}
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
if (grammar->partial_utf8.n_remain > 0) {
candidates_decoded.reserve(candidates->size);
}
std::vector<llama_grammar_candidate> candidates_grammar;
candidates_grammar.reserve(candidates->size);
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);
const auto & piece = grammar->token_pieces[id];
if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else if (grammar->partial_utf8.n_remain == 0){
const auto & decoded = grammar->token_codepoints.at(id);
candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
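The cache is only consulted when `grammar->partial_utf8.n_remain == 0`: the cached codepoints were decoded from a clean UTF-8 state, so they are wrong whenever a previously accepted token ended mid-character. A small illustration of that case (background only, not code from the commit):

```cpp
// "€" is the three-byte UTF-8 sequence E2 82 AC. A BPE tokenizer may split it:
const std::string tok1 = "\xE2\x82"; // incomplete char -> partial_utf8.n_remain becomes 1
const std::string tok2 = "\xAC";     // a lone continuation byte
// decode_utf8(tok2, {0, 0}) -- what the cache stores -- cannot yield U+20AC;
// only decode_utf8(tok2, grammar->partial_utf8) can complete the code point,
// which is why this path falls back to decoding on the fly.
```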
@@ -13763,10 +13778,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const auto & piece = grammar->token_pieces.at(token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto   decoded     = grammar->partial_utf8.n_remain == 0
+        ? grammar->token_codepoints[token]
+        : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {

llama.h

@@ -961,7 +961,7 @@ extern "C" {
     LLAMA_API void llama_sample_grammar(
             struct llama_context * ctx,
           llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
+            struct llama_grammar * grammar);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
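The only API-visible change is that the grammar argument loses its `const`, since `llama_sample_grammar` may now fill the caches on first use. A typical caller, sketched with illustrative scaffolding (`candidates` is assumed to be a `std::vector<llama_token_data>` filled from the logits; `llama_sample_grammar`, `llama_sample_token`, and `llama_grammar_accept_token` are the existing public API):

```cpp
// Illustrative grammar-constrained sampling step.
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

llama_sample_grammar(ctx, &cur_p, grammar);        // first call builds the caches
const llama_token id = llama_sample_token(ctx, &cur_p);
llama_grammar_accept_token(ctx, grammar, id);      // reuses the cached pieces/codepoints
```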
@@ -1099,6 +1099,11 @@ struct llama_grammar {
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
+
+    // caching the token pieces & their decoded codepoints.
+    std::vector<std::string>                           token_pieces;
+    std::vector<std::pair<std::vector<uint32_t>,
+                          llama_partial_utf8>>         token_codepoints;
 };
 
 struct llama_grammar_candidate {
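The new members trade memory for speed: one string and one decoded-codepoint vector per vocabulary entry, kept for the lifetime of the grammar. A rough, hypothetical way to gauge that cost — this helper is not part of the commit, and it assumes the struct internals are visible (as in this internal section of llama.h):

```cpp
// Hypothetical: approximate heap usage of the two per-token caches.
static size_t grammar_cache_bytes(const llama_grammar * g) {
    size_t bytes = 0;
    for (const auto & piece : g->token_pieces) {
        bytes += piece.capacity();                        // raw token text
    }
    for (const auto & cp : g->token_codepoints) {
        bytes += cp.first.capacity() * sizeof(uint32_t);  // decoded codepoints
    }
    return bytes;
}
```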