update for new APIs

This commit is contained in:
Michal Moskal 2025-01-25 15:49:07 -08:00
parent 76290d9ea0
commit f19655c4c0
4 changed files with 35 additions and 17 deletions

View file

@ -991,7 +991,7 @@ public:
}; };
std::string json_schema_to_grammar(const json & schema) { std::string json_schema_to_grammar(const json & schema) {
#ifdef LLAMA_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
return "llg:json:" + schema.dump(); return "llg:json:" + schema.dump();
#else #else
return build_grammar([&](const llama_grammar_builder & callbacks) { return build_grammar([&](const llama_grammar_builder & callbacks) {

View file

@ -1,8 +1,15 @@
#ifdef LLAMA_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"
#include "llguidance.h" #include "llguidance.h"
struct llama_sampler_llg { struct llama_sampler_llg {
const struct llama_model * model; const struct llama_model * model;
const struct llama_vocab * vocab;
std::string grammar_kind; std::string grammar_kind;
std::string grammar_data; std::string grammar_data;
LlgTokenizer *tokenizer; LlgTokenizer *tokenizer;
@ -17,7 +24,7 @@ static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
llg_constraint_init_set_defaults(&cinit, tokenizer); llg_constraint_init_set_defaults(&cinit, tokenizer);
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data); auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
if (llg_get_error(c)) { if (llg_get_error(c)) {
LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c)); LOG_ERR("llg error: %s\n", llg_get_error(c));
llg_free_constraint(c); llg_free_constraint(c);
return nullptr; return nullptr;
} }
@ -44,7 +51,7 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) { if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
ctx->has_llg_res = true; ctx->has_llg_res = true;
} else { } else {
LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar)); LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
llg_free_constraint(ctx->grammar); llg_free_constraint(ctx->grammar);
ctx->grammar = nullptr; ctx->grammar = nullptr;
} }
@ -52,7 +59,7 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat
if (ctx->has_llg_res) { if (ctx->has_llg_res) {
if (ctx->llg_res.is_stop) { if (ctx->llg_res.is_stop) {
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
if (!llama_token_is_eog(ctx->model, cur_p->data[i].id)) { if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY; cur_p->data[i].logit = -INFINITY;
} }
} }
@ -128,8 +135,8 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
uint32_t *output_tokens, uint32_t *output_tokens,
size_t output_tokens_len) size_t output_tokens_len)
{ {
const struct llama_model *model = (const struct llama_model *)user_data; const struct llama_vocab *vocab = (const struct llama_vocab *)user_data;
int r = llama_tokenize(model, (const char *) bytes, bytes_len, int r = llama_tokenize(vocab, (const char *) bytes, bytes_len,
(int32_t*)output_tokens, output_tokens_len, false, true); (int32_t*)output_tokens, output_tokens_len, false, true);
if (r < 0) if (r < 0)
return -r; return -r;
@ -145,11 +152,13 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
return llg_clone_tokenizer(tokenizer_cache); return llg_clone_tokenizer(tokenizer_cache);
} }
auto tok_eos = llama_token_eot(model); const struct llama_vocab *vocab = llama_model_get_vocab(model);
if (tok_eos == LLAMA_TOKEN_NULL)
tok_eos = llama_token_eos(model);
size_t vocab_size = llama_n_vocab(model); auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL)
tok_eos = llama_vocab_eos(vocab);
size_t vocab_size = llama_vocab_n_tokens(vocab);
auto token_lens = new uint32_t[vocab_size]; auto token_lens = new uint32_t[vocab_size];
// we typically have ~7 bytes per token; let's go on the safe side here // we typically have ~7 bytes per token; let's go on the safe side here
@ -165,12 +174,12 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
llama_token token = i; llama_token token = i;
auto dp = (char *) token_bytes + offset; auto dp = (char *) token_bytes + offset;
auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false); auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
if (size < 0) { if (size < 0) {
GGML_ABORT("llama_detokenize failed\n"); GGML_ABORT("llama_detokenize failed\n");
} }
if (size == 0) { if (size == 0) {
size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true); size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
if (size < 0) { if (size < 0) {
GGML_ABORT("llama_detokenize failed\n"); GGML_ABORT("llama_detokenize failed\n");
} }
@ -194,7 +203,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
/* .tokenize_assumes_string = */ false, /* .tokenize_assumes_string = */ false,
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn, /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
/* .use_approximate_greedy_tokenize_fn = */ false, /* .use_approximate_greedy_tokenize_fn = */ false,
/* .tokenize_user_data = */ model, /* .tokenize_user_data = */ vocab,
}; };
char error_buffer[1024]; char error_buffer[1024];
@ -204,7 +213,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
delete[] token_lens; delete[] token_lens;
if (tokenizer == nullptr) { if (tokenizer == nullptr) {
LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer); LOG_ERR("llg tokenizer error: %s\n", error_buffer);
return tokenizer; return tokenizer;
} }
@ -221,10 +230,13 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
const char * grammar_kind, const char * grammar_data) { const char * grammar_kind, const char * grammar_data) {
auto * ctx = new llama_sampler_llg; auto * ctx = new llama_sampler_llg;
const llama_vocab * vocab = llama_model_get_vocab(model);
if (grammar_kind != nullptr && grammar_kind[0] != '\0') { if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
auto tokenizer = llama_sampler_llg_new_tokenizer(model); auto tokenizer = llama_sampler_llg_new_tokenizer(model);
*ctx = { *ctx = {
/* .model = */ model, /* .model = */ model,
/* .vocab = */ vocab,
/* .grammar_kind = */ grammar_kind, /* .grammar_kind = */ grammar_kind,
/* .grammar_data = */ grammar_data, /* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer, /* .tokenizer = */ tokenizer,
@ -235,6 +247,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
} else { } else {
*ctx = { *ctx = {
/* .model = */ model, /* .model = */ model,
/* .vocab = */ vocab,
/* .grammar_kind = */ {}, /* .grammar_kind = */ {},
/* .grammar_data = */ {}, /* .grammar_data = */ {},
/* .tokenizer = */ nullptr, /* .tokenizer = */ nullptr,

View file

@ -153,7 +153,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
struct llama_sampler * grmr; struct llama_sampler * grmr;
if (params.grammar.compare(0, 4, "llg:") == 0) { if (params.grammar.compare(0, 4, "llg:") == 0) {
#ifdef LLAMA_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
auto gp = params.grammar.find(':', 4); auto gp = params.grammar.find(':', 4);
if (gp == std::string::npos) { if (gp == std::string::npos) {
GGML_ABORT("invalid serialized grammar"); GGML_ABORT("invalid serialized grammar");
@ -162,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
auto grm_data = params.grammar.c_str() + gp + 1; auto grm_data = params.grammar.c_str() + gp + 1;
grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data); grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data);
#else #else
GGML_ABORT("llguidance (LLAMA_LLGUIDANCE cmake parameter) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif #endif
} else { } else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");

View file

@ -102,3 +102,8 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars); std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
#ifdef LLAMA_USE_LLGUIDANCE
struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
const char * grammar_kind, const char * grammar_data);
#endif