diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 4b1401ef1..46c2dc4d7 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -991,7 +991,7 @@ public: }; std::string json_schema_to_grammar(const json & schema) { -#ifdef LLAMA_LLGUIDANCE +#ifdef LLAMA_USE_LLGUIDANCE return "llg:json:" + schema.dump(); #else return build_grammar([&](const llama_grammar_builder & callbacks) { diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 4120e93b2..623211747 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -1,8 +1,15 @@ -#ifdef LLAMA_LLGUIDANCE +#ifdef LLAMA_USE_LLGUIDANCE + +#include "common.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" + #include "llguidance.h" struct llama_sampler_llg { const struct llama_model * model; + const struct llama_vocab * vocab; std::string grammar_kind; std::string grammar_data; LlgTokenizer *tokenizer; @@ -17,7 +24,7 @@ static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer, llg_constraint_init_set_defaults(&cinit, tokenizer); auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data); if (llg_get_error(c)) { - LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c)); + LOG_ERR("llg error: %s\n", llg_get_error(c)); llg_free_constraint(c); return nullptr; } @@ -44,7 +51,7 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) { ctx->has_llg_res = true; } else { - LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar)); + LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar)); llg_free_constraint(ctx->grammar); ctx->grammar = nullptr; } @@ -52,7 +59,7 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat if (ctx->has_llg_res) { if (ctx->llg_res.is_stop) { for (size_t i = 0; i < cur_p->size; ++i) { - if (!llama_token_is_eog(ctx->model, cur_p->data[i].id)) { + if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) { cur_p->data[i].logit = -INFINITY; } } @@ -128,8 +135,8 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data, uint32_t *output_tokens, size_t output_tokens_len) { - const struct llama_model *model = (const struct llama_model *)user_data; - int r = llama_tokenize(model, (const char *) bytes, bytes_len, + const struct llama_vocab *vocab = (const struct llama_vocab *)user_data; + int r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t*)output_tokens, output_tokens_len, false, true); if (r < 0) return -r; @@ -145,11 +152,13 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * return llg_clone_tokenizer(tokenizer_cache); } - auto tok_eos = llama_token_eot(model); - if (tok_eos == LLAMA_TOKEN_NULL) - tok_eos = llama_token_eos(model); + const struct llama_vocab *vocab = llama_model_get_vocab(model); - size_t vocab_size = llama_n_vocab(model); + auto tok_eos = llama_vocab_eot(vocab); + if (tok_eos == LLAMA_TOKEN_NULL) + tok_eos = llama_vocab_eos(vocab); + + size_t vocab_size = llama_vocab_n_tokens(vocab); auto token_lens = new uint32_t[vocab_size]; // we typically have ~7 bytes per token; let's go on the safe side here @@ -165,12 +174,12 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * llama_token token = i; auto dp = (char *) token_bytes + offset; - auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false); + auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false); if (size < 0) { GGML_ABORT("llama_detokenize failed\n"); } if (size == 0) { - size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true); + size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true); if (size < 0) { GGML_ABORT("llama_detokenize failed\n"); } @@ -194,7 +203,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * /* .tokenize_assumes_string = */ false, /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn, /* .use_approximate_greedy_tokenize_fn = */ false, - /* .tokenize_user_data = */ model, + /* .tokenize_user_data = */ vocab, }; char error_buffer[1024]; @@ -204,7 +213,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * delete[] token_lens; if (tokenizer == nullptr) { - LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer); + LOG_ERR("llg tokenizer error: %s\n", error_buffer); return tokenizer; } @@ -221,10 +230,13 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model, const char * grammar_kind, const char * grammar_data) { auto * ctx = new llama_sampler_llg; + const llama_vocab * vocab = llama_model_get_vocab(model); + if (grammar_kind != nullptr && grammar_kind[0] != '\0') { auto tokenizer = llama_sampler_llg_new_tokenizer(model); *ctx = { /* .model = */ model, + /* .vocab = */ vocab, /* .grammar_kind = */ grammar_kind, /* .grammar_data = */ grammar_data, /* .tokenizer = */ tokenizer, @@ -235,6 +247,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model, } else { *ctx = { /* .model = */ model, + /* .vocab = */ vocab, /* .grammar_kind = */ {}, /* .grammar_data = */ {}, /* .tokenizer = */ nullptr, diff --git a/common/sampling.cpp b/common/sampling.cpp index e12301bc4..97691c04b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -153,7 +153,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co struct llama_sampler * grmr; if (params.grammar.compare(0, 4, "llg:") == 0) { -#ifdef LLAMA_LLGUIDANCE +#ifdef LLAMA_USE_LLGUIDANCE auto gp = params.grammar.find(':', 4); if (gp == std::string::npos) { GGML_ABORT("invalid serialized grammar"); @@ -162,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co auto grm_data = params.grammar.c_str() + gp + 1; grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data); #else - GGML_ABORT("llguidance (LLAMA_LLGUIDANCE cmake parameter) is not enabled"); + GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif } else { grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); diff --git a/common/sampling.h b/common/sampling.h index 348911b18..59baee133 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -102,3 +102,8 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr); std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector common_sampler_types_from_chars(const std::string & chars); + +#ifdef LLAMA_USE_LLGUIDANCE +struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model, + const char * grammar_kind, const char * grammar_data); +#endif \ No newline at end of file