diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 61a1f4a6b..94410d713 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -8,7 +8,6 @@ #include "llguidance.h" struct llama_sampler_llg { - const llama_model * model; const llama_vocab * vocab; std::string grammar_kind; std::string grammar_data; @@ -91,7 +90,7 @@ static void llama_sampler_llg_reset(llama_sampler * smpl) { static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { const auto * ctx = (const llama_sampler_llg *) smpl->ctx; - auto * result = llama_sampler_init_llg(ctx->model, nullptr, nullptr); + auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr); // copy the state { @@ -143,17 +142,15 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data, return r; } -static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_model * model) { - // TODO store the tokenizer in the model somehow - static const llama_model *model_cache; +static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) { + // TODO store the tokenizer in the vocab somehow + static const llama_vocab *vocab_cache; static LlgTokenizer *tokenizer_cache; - if (model_cache == model) { + if (vocab_cache == vocab) { return llg_clone_tokenizer(tokenizer_cache); } - const llama_vocab *vocab = llama_model_get_vocab(model); - auto tok_eos = llama_vocab_eot(vocab); if (tok_eos == LLAMA_TOKEN_NULL) tok_eos = llama_vocab_eos(vocab); @@ -220,22 +217,19 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_model * model) if (tokenizer_cache) { llg_free_tokenizer(tokenizer_cache); } - model_cache = model; + vocab_cache = vocab; tokenizer_cache = tokenizer; return tokenizer; } -llama_sampler * llama_sampler_init_llg(const llama_model * model, +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data) { auto * ctx = new llama_sampler_llg; - const llama_vocab * vocab = llama_model_get_vocab(model); - if (grammar_kind != nullptr && grammar_kind[0] != '\0') { - auto tokenizer = llama_sampler_llg_new_tokenizer(model); + auto tokenizer = llama_sampler_llg_new_tokenizer(vocab); *ctx = { - /* .model = */ model, /* .vocab = */ vocab, /* .grammar_kind = */ grammar_kind, /* .grammar_data = */ grammar_data, @@ -246,7 +240,6 @@ llama_sampler * llama_sampler_init_llg(const llama_model * model, }; } else { *ctx = { - /* .model = */ model, /* .vocab = */ vocab, /* .grammar_kind = */ {}, /* .grammar_data = */ {}, diff --git a/common/sampling.cpp b/common/sampling.cpp index 309207af4..9661ed419 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co struct llama_sampler * grmr; if (params.grammar.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE - grmr = llama_sampler_init_llg(model, "lark", params.grammar.c_str()); + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE diff --git a/common/sampling.h b/common/sampling.h index ae2be0359..33dfcb1f4 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -104,6 +104,6 @@ std::vector common_sampler_types_from_names(const std: std::vector common_sampler_types_from_chars(const std::string & chars); #ifdef LLAMA_USE_LLGUIDANCE -struct llama_sampler * llama_sampler_init_llg(const llama_model * model, +struct llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); #endif // LLAMA_USE_LLGUIDANCE \ No newline at end of file