From 5a64af6c70b22db374475039be7ccd891111af51 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 30 Jan 2025 14:02:37 +0000 Subject: [PATCH] add llama_sampler_init_grammar_lazy instead of renaming the non-lazy --- common/sampling.cpp | 9 +++++---- include/llama.h | 11 ++++++----- src/llama-sampling.cpp | 41 +++++++++++++++++++++++++++++++---------- 3 files changed, 42 insertions(+), 19 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 20026d2de..bc7e49fdb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -158,10 +158,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_grammar_init(vocab, params.grammar.c_str(), "root", - params.grammar_lazy, - trigger_words.data(), trigger_words.size(), - params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()), + /* .grmr = */ params.grammar_lazy + ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", + trigger_words.data(), trigger_words.size(), + params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"), /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, diff --git a/include/llama.h b/include/llama.h index 32a3de051..61907ed40 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1194,17 +1194,18 @@ extern "C" { float tau, float eta); - DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar( + LLAMA_API struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, - const char * grammar_root), - "use llama_sampler_grammar_init instead"); + const char * grammar_root); - LLAMA_API struct llama_sampler * llama_sampler_grammar_init( + /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 + /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, - bool lazy, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 67c921b8b..26974f539 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1433,6 +1433,17 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token } } +// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (!ctx->grammar) { @@ -1454,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_grammar_init(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); // copy the state { @@ -1490,15 +1501,7 @@ static struct llama_sampler_i llama_sampler_grammar_i = { /* .free = */ llama_sampler_grammar_free, }; - -struct llama_sampler * llama_sampler_init_grammar( - const struct llama_vocab * vocab, - const char * grammar_str, - const char * grammar_root) { - return llama_sampler_grammar_init(vocab, grammar_str, grammar_root, false, nullptr, 0, nullptr, 0); -} - -struct llama_sampler * llama_sampler_grammar_init( +static struct llama_sampler * llama_sampler_init_grammar_impl( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, @@ -1531,6 +1534,24 @@ struct llama_sampler * llama_sampler_grammar_init( }; } +struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens); +} + // penalties struct llama_sampler_penalties {