pass vocab not model to llama_sampler_init_llg()

This commit is contained in:
Michal Moskal 2025-01-26 08:16:56 -08:00
parent de269a1833
commit a7be6669b1
3 changed files with 10 additions and 17 deletions

View file

@ -8,7 +8,6 @@
#include "llguidance.h" #include "llguidance.h"
struct llama_sampler_llg { struct llama_sampler_llg {
const llama_model * model;
const llama_vocab * vocab; const llama_vocab * vocab;
std::string grammar_kind; std::string grammar_kind;
std::string grammar_data; std::string grammar_data;
@ -91,7 +90,7 @@ static void llama_sampler_llg_reset(llama_sampler * smpl) {
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_llg *) smpl->ctx; const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
auto * result = llama_sampler_init_llg(ctx->model, nullptr, nullptr); auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
// copy the state // copy the state
{ {
@ -143,17 +142,15 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
return r; return r;
} }
static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_model * model) { static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
// TODO store the tokenizer in the model somehow // TODO store the tokenizer in the vocab somehow
static const llama_model *model_cache; static const llama_vocab *vocab_cache;
static LlgTokenizer *tokenizer_cache; static LlgTokenizer *tokenizer_cache;
if (model_cache == model) { if (vocab_cache == vocab) {
return llg_clone_tokenizer(tokenizer_cache); return llg_clone_tokenizer(tokenizer_cache);
} }
const llama_vocab *vocab = llama_model_get_vocab(model);
auto tok_eos = llama_vocab_eot(vocab); auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL) if (tok_eos == LLAMA_TOKEN_NULL)
tok_eos = llama_vocab_eos(vocab); tok_eos = llama_vocab_eos(vocab);
@ -220,22 +217,19 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_model * model)
if (tokenizer_cache) { if (tokenizer_cache) {
llg_free_tokenizer(tokenizer_cache); llg_free_tokenizer(tokenizer_cache);
} }
model_cache = model; vocab_cache = vocab;
tokenizer_cache = tokenizer; tokenizer_cache = tokenizer;
return tokenizer; return tokenizer;
} }
llama_sampler * llama_sampler_init_llg(const llama_model * model, llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data) { const char * grammar_kind, const char * grammar_data) {
auto * ctx = new llama_sampler_llg; auto * ctx = new llama_sampler_llg;
const llama_vocab * vocab = llama_model_get_vocab(model);
if (grammar_kind != nullptr && grammar_kind[0] != '\0') { if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
auto tokenizer = llama_sampler_llg_new_tokenizer(model); auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
*ctx = { *ctx = {
/* .model = */ model,
/* .vocab = */ vocab, /* .vocab = */ vocab,
/* .grammar_kind = */ grammar_kind, /* .grammar_kind = */ grammar_kind,
/* .grammar_data = */ grammar_data, /* .grammar_data = */ grammar_data,
@ -246,7 +240,6 @@ llama_sampler * llama_sampler_init_llg(const llama_model * model,
}; };
} else { } else {
*ctx = { *ctx = {
/* .model = */ model,
/* .vocab = */ vocab, /* .vocab = */ vocab,
/* .grammar_kind = */ {}, /* .grammar_kind = */ {},
/* .grammar_data = */ {}, /* .grammar_data = */ {},

View file

@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
struct llama_sampler * grmr; struct llama_sampler * grmr;
if (params.grammar.compare(0, 11, "%llguidance") == 0) { if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(model, "lark", params.grammar.c_str()); grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else #else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE

View file

@ -104,6 +104,6 @@ std::vector<enum common_sampler_type> common_sampler_types_from_names(const std:
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars); std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
#ifdef LLAMA_USE_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
struct llama_sampler * llama_sampler_init_llg(const llama_model * model, struct llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data); const char * grammar_kind, const char * grammar_data);
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE