pass vocab not model to llama_sampler_init_llg()
This commit is contained in:
parent
de269a1833
commit
a7be6669b1
3 changed files with 10 additions and 17 deletions
|
@ -8,7 +8,6 @@
|
||||||
#include "llguidance.h"
|
#include "llguidance.h"
|
||||||
|
|
||||||
struct llama_sampler_llg {
|
struct llama_sampler_llg {
|
||||||
const llama_model * model;
|
|
||||||
const llama_vocab * vocab;
|
const llama_vocab * vocab;
|
||||||
std::string grammar_kind;
|
std::string grammar_kind;
|
||||||
std::string grammar_data;
|
std::string grammar_data;
|
||||||
|
@ -91,7 +90,7 @@ static void llama_sampler_llg_reset(llama_sampler * smpl) {
|
||||||
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
|
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
|
||||||
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
|
||||||
|
|
||||||
auto * result = llama_sampler_init_llg(ctx->model, nullptr, nullptr);
|
auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
|
||||||
|
|
||||||
// copy the state
|
// copy the state
|
||||||
{
|
{
|
||||||
|
@ -143,17 +142,15 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_model * model) {
|
static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
|
||||||
// TODO store the tokenizer in the model somehow
|
// TODO store the tokenizer in the vocab somehow
|
||||||
static const llama_model *model_cache;
|
static const llama_vocab *vocab_cache;
|
||||||
static LlgTokenizer *tokenizer_cache;
|
static LlgTokenizer *tokenizer_cache;
|
||||||
|
|
||||||
if (model_cache == model) {
|
if (vocab_cache == vocab) {
|
||||||
return llg_clone_tokenizer(tokenizer_cache);
|
return llg_clone_tokenizer(tokenizer_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_vocab *vocab = llama_model_get_vocab(model);
|
|
||||||
|
|
||||||
auto tok_eos = llama_vocab_eot(vocab);
|
auto tok_eos = llama_vocab_eot(vocab);
|
||||||
if (tok_eos == LLAMA_TOKEN_NULL)
|
if (tok_eos == LLAMA_TOKEN_NULL)
|
||||||
tok_eos = llama_vocab_eos(vocab);
|
tok_eos = llama_vocab_eos(vocab);
|
||||||
|
@ -220,22 +217,19 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_model * model)
|
||||||
if (tokenizer_cache) {
|
if (tokenizer_cache) {
|
||||||
llg_free_tokenizer(tokenizer_cache);
|
llg_free_tokenizer(tokenizer_cache);
|
||||||
}
|
}
|
||||||
model_cache = model;
|
vocab_cache = vocab;
|
||||||
tokenizer_cache = tokenizer;
|
tokenizer_cache = tokenizer;
|
||||||
|
|
||||||
return tokenizer;
|
return tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler * llama_sampler_init_llg(const llama_model * model,
|
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||||
const char * grammar_kind, const char * grammar_data) {
|
const char * grammar_kind, const char * grammar_data) {
|
||||||
auto * ctx = new llama_sampler_llg;
|
auto * ctx = new llama_sampler_llg;
|
||||||
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
|
||||||
|
|
||||||
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
|
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
|
||||||
auto tokenizer = llama_sampler_llg_new_tokenizer(model);
|
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
|
||||||
*ctx = {
|
*ctx = {
|
||||||
/* .model = */ model,
|
|
||||||
/* .vocab = */ vocab,
|
/* .vocab = */ vocab,
|
||||||
/* .grammar_kind = */ grammar_kind,
|
/* .grammar_kind = */ grammar_kind,
|
||||||
/* .grammar_data = */ grammar_data,
|
/* .grammar_data = */ grammar_data,
|
||||||
|
@ -246,7 +240,6 @@ llama_sampler * llama_sampler_init_llg(const llama_model * model,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
*ctx = {
|
*ctx = {
|
||||||
/* .model = */ model,
|
|
||||||
/* .vocab = */ vocab,
|
/* .vocab = */ vocab,
|
||||||
/* .grammar_kind = */ {},
|
/* .grammar_kind = */ {},
|
||||||
/* .grammar_data = */ {},
|
/* .grammar_data = */ {},
|
||||||
|
|
|
@ -154,7 +154,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
struct llama_sampler * grmr;
|
struct llama_sampler * grmr;
|
||||||
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
||||||
#ifdef LLAMA_USE_LLGUIDANCE
|
#ifdef LLAMA_USE_LLGUIDANCE
|
||||||
grmr = llama_sampler_init_llg(model, "lark", params.grammar.c_str());
|
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
||||||
#else
|
#else
|
||||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
|
|
|
@ -104,6 +104,6 @@ std::vector<enum common_sampler_type> common_sampler_types_from_names(const std:
|
||||||
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|
||||||
|
|
||||||
#ifdef LLAMA_USE_LLGUIDANCE
|
#ifdef LLAMA_USE_LLGUIDANCE
|
||||||
struct llama_sampler * llama_sampler_init_llg(const llama_model * model,
|
struct llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||||
const char * grammar_kind, const char * grammar_data);
|
const char * grammar_kind, const char * grammar_data);
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
#endif // LLAMA_USE_LLGUIDANCE
|
Loading…
Add table
Add a link
Reference in a new issue