From f29add56d8d8fb8ba8e107a5096b726aeb47adf2 Mon Sep 17 00:00:00 2001 From: marcus Date: Fri, 24 Nov 2023 20:47:01 -0800 Subject: [PATCH] changed allowed saving of pieces to reduce calls to llama_token_to_piece --- common/sampling.cpp | 2 +- llama.cpp | 15 +++++++++++++-- llama.h | 4 +++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 1317024c2..8a4b1f438 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -166,7 +166,7 @@ llama_token llama_sampling_sample( } if (ctx_sampling->grammar != NULL) { - llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); + llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar, nullptr); } if (temp < 0.0) { diff --git a/llama.cpp b/llama.cpp index f2b5967d7..0e87da50a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7106,7 +7106,11 @@ void llama_sample_repetition_penalties( } } -void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { +void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar, + char const * const * pieces) { GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); @@ -7125,7 +7129,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_piece(ctx, id); + std::string piece; + + if (pieces != nullptr && pieces[id] != nullptr) { + piece = std::string(pieces[id]); + } else { + piece = llama_token_to_piece(ctx, id); + } + if (id == eos) { if (!allow_eos) { candidates->data[i].logit = -INFINITY; diff --git a/llama.h b/llama.h index 1a62058d1..7f2a1a6c6 100644 --- a/llama.h +++ b/llama.h @@ -722,7 +722,9 @@ extern "C" { LLAMA_API void llama_sample_grammar( struct llama_context * ctx, llama_token_data_array * candidates, - const struct llama_grammar * grammar); + const struct llama_grammar * grammar, + char const * const * pieces); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.