commit f19655c4c0 (parent 76290d9ea0)

    update for new APIs

4 changed files with 35 additions and 17 deletions

@@ -991,7 +991,7 @@ public:
 };
 
 std::string json_schema_to_grammar(const json & schema) {
-#ifdef LLAMA_LLGUIDANCE
+#ifdef LLAMA_USE_LLGUIDANCE
     return "llg:json:" + schema.dump();
 #else
     return build_grammar([&](const llama_grammar_builder & callbacks) {
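
The first file switches the feature guard from LLAMA_LLGUIDANCE to LLAMA_USE_LLGUIDANCE and, with llguidance enabled, returns a serialized "llg:json:<schema>" string instead of building a GBNF grammar. Below is a minimal sketch of the "llg:<kind>:<data>" convention this implies; the helper name parse_llg_grammar is ours (not in the commit), and the splitting logic mirrors the find(':', 4) parsing in common_sampler_init further down.

    #include <string>
    #include <utility>

    // Hypothetical helper: split a serialized "llg:<kind>:<data>" grammar
    // string, e.g. "llg:json:{...}", into the grammar kind ("json") and
    // the payload (here, the dumped JSON schema).
    static std::pair<std::string, std::string> parse_llg_grammar(const std::string & s) {
        auto gp = s.find(':', 4);            // first ':' after the "llg:" prefix
        if (gp == std::string::npos) {
            return {"", ""};                 // malformed; the real code aborts instead
        }
        return {s.substr(4, gp - 4), s.substr(gp + 1)};
    }
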
@@ -1,8 +1,15 @@
-#ifdef LLAMA_LLGUIDANCE
+#ifdef LLAMA_USE_LLGUIDANCE
+
+#include "common.h"
+#include "sampling.h"
+#include "log.h"
+#include "llama.h"
 
 #include "llguidance.h"
 
 struct llama_sampler_llg {
     const struct llama_model * model;
+    const struct llama_vocab * vocab;
     std::string grammar_kind;
     std::string grammar_data;
     LlgTokenizer *tokenizer;
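
The sampler context now caches a const llama_vocab * next to the model pointer: in the refactored llama.cpp API, token-level operations (EOG checks, tokenize/detokenize, special-token lookups) take the vocab handle rather than the model. A minimal sketch of the lookup pattern, using only calls that appear in this diff:

    // Fetch the vocab handle once from the model, then use it everywhere
    // the old API took a llama_model pointer.
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_token eot     = llama_vocab_eot(vocab);       // end-of-turn, or LLAMA_TOKEN_NULL
    int32_t     n_vocab = llama_vocab_n_tokens(vocab);  // replaces llama_n_vocab(model)
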
@@ -17,7 +24,7 @@ static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
     llg_constraint_init_set_defaults(&cinit, tokenizer);
     auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
     if (llg_get_error(c)) {
-        LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c));
+        LOG_ERR("llg error: %s\n", llg_get_error(c));
         llg_free_constraint(c);
         return nullptr;
     }
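
LLAMA_LOG_ERROR is libllama's internal logging macro; code in the common library logs through LOG_ERR from common/log.h instead, which is presumably why the first hunk adds #include "log.h". The call site keeps the same printf-style shape:

    #include "log.h"  // common/log.h; defines LOG_ERR and friends

    // same arguments as before, only the macro changes:
    LOG_ERR("llg error: %s\n", llg_get_error(c));
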
@@ -44,7 +51,7 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat
     if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
         ctx->has_llg_res = true;
     } else {
-        LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar));
+        LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
         llg_free_constraint(ctx->grammar);
         ctx->grammar = nullptr;
     }
@@ -52,7 +59,7 @@ static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_dat
     if (ctx->has_llg_res) {
         if (ctx->llg_res.is_stop) {
             for (size_t i = 0; i < cur_p->size; ++i) {
-                if (!llama_token_is_eog(ctx->model, cur_p->data[i].id)) {
+                if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                     cur_p->data[i].logit = -INFINITY;
                 }
             }
@@ -128,8 +135,8 @@ static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
                                             uint32_t *output_tokens,
                                             size_t output_tokens_len)
 {
-    const struct llama_model *model = (const struct llama_model *)user_data;
-    int r = llama_tokenize(model, (const char *) bytes, bytes_len,
+    const struct llama_vocab *vocab = (const struct llama_vocab *)user_data;
+    int r = llama_tokenize(vocab, (const char *) bytes, bytes_len,
                            (int32_t*)output_tokens, output_tokens_len, false, true);
     if (r < 0)
         return -r;
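
The tokenize callback forwards llama_tokenize's size-negotiation convention to llguidance: on success the return value is the token count, and when the output buffer is too small it is the negative of the required count, hence return -r. A sketch of the usual two-pass pattern that convention supports (fragment; vocab, text, and text_len come from surrounding context):

    #include <vector>

    std::vector<llama_token> toks(16);
    int32_t n = llama_tokenize(vocab, text, text_len, toks.data(), (int32_t) toks.size(),
                               /*add_special=*/false, /*parse_special=*/true);
    if (n < 0) {
        toks.resize(-n);  // negative return = minus the required token count; retry
        n = llama_tokenize(vocab, text, text_len, toks.data(), (int32_t) toks.size(),
                           false, true);
    }
    toks.resize(n);
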
@@ -145,11 +152,13 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
         return llg_clone_tokenizer(tokenizer_cache);
     }
 
-    auto tok_eos = llama_token_eot(model);
-    if (tok_eos == LLAMA_TOKEN_NULL)
-        tok_eos = llama_token_eos(model);
+    const struct llama_vocab *vocab = llama_model_get_vocab(model);
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL)
+        tok_eos = llama_vocab_eos(vocab);
 
-    size_t vocab_size = llama_n_vocab(model);
+    size_t vocab_size = llama_vocab_n_tokens(vocab);
 
     auto token_lens = new uint32_t[vocab_size];
     // we typically have ~7 bytes per token; let's go on the safe side here
@@ -165,12 +174,12 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
 
         llama_token token = i;
         auto dp = (char *) token_bytes + offset;
-        auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false);
+        auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
         if (size < 0) {
             GGML_ABORT("llama_detokenize failed\n");
         }
         if (size == 0) {
-            size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true);
+            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
             if (size < 0) {
                 GGML_ABORT("llama_detokenize failed\n");
             }
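
This loop exists because llguidance builds its own tokenizer from a full token table: for every token id it needs the token's byte representation, which the code obtains by detokenizing one id at a time (retrying with unparse_special = true when a special token renders to zero bytes). A simplified sketch of the same idea, minus the packed-buffer bookkeeping the real code does with token_bytes and token_lens:

    #include <string>
    #include <vector>

    // Build a per-token byte table (simplified illustration only).
    size_t vocab_size = llama_vocab_n_tokens(vocab);
    std::vector<std::string> table(vocab_size);
    for (llama_token id = 0; id < (llama_token) vocab_size; ++id) {
        char buf[128];
        int32_t n = llama_detokenize(vocab, &id, 1, buf, sizeof(buf),
                                     /*remove_special=*/false, /*unparse_special=*/true);
        if (n > 0) {
            table[id] = std::string(buf, n);
        }
    }
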
@@ -194,7 +203,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
         /* .tokenize_assumes_string = */ false,
         /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
         /* .use_approximate_greedy_tokenize_fn = */ false,
-        /* .tokenize_user_data = */ model,
+        /* .tokenize_user_data = */ vocab,
     };
 
     char error_buffer[1024];
@@ -204,7 +213,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model *
     delete[] token_lens;
 
     if (tokenizer == nullptr) {
-        LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer);
+        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
         return tokenizer;
     }
 
@@ -221,10 +230,13 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
                                              const char * grammar_kind, const char * grammar_data) {
     auto * ctx = new llama_sampler_llg;
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
         auto tokenizer = llama_sampler_llg_new_tokenizer(model);
         *ctx = {
             /* .model = */ model,
+            /* .vocab = */ vocab,
             /* .grammar_kind = */ grammar_kind,
             /* .grammar_data = */ grammar_data,
             /* .tokenizer = */ tokenizer,
@@ -235,6 +247,7 @@ struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
     } else {
         *ctx = {
             /* .model = */ model,
+            /* .vocab = */ vocab,
             /* .grammar_kind = */ {},
             /* .grammar_data = */ {},
             /* .tokenizer = */ nullptr,
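
llama_sampler_init_llg now resolves the vocab once and stores it in both constructor branches, so the apply path never has to touch the model again. Hypothetical direct usage; the "json" kind matches what json_schema_to_grammar emits above, and the schema literal is only an example:

    llama_sampler * smpl = llama_sampler_init_llg(model, "json",
            R"({"type": "object", "properties": {"answer": {"type": "string"}}})");
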
@@ -153,7 +153,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     struct llama_sampler * grmr;
     if (params.grammar.compare(0, 4, "llg:") == 0) {
-#ifdef LLAMA_LLGUIDANCE
+#ifdef LLAMA_USE_LLGUIDANCE
         auto gp = params.grammar.find(':', 4);
         if (gp == std::string::npos) {
             GGML_ABORT("invalid serialized grammar");
@@ -162,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         auto grm_data = params.grammar.c_str() + gp + 1;
         grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data);
 #else
-        GGML_ABORT("llguidance (LLAMA_LLGUIDANCE cmake parameter) is not enabled");
+        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif
     } else {
         grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
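
End to end, common_sampler_init routes any params.grammar that starts with "llg:" to llguidance and everything else to the stock GBNF grammar sampler. A hedged sketch of the whole path, reusing the json_schema_to_grammar output from the first file (common_params_sampling and its grammar field are taken from the surrounding code; schema is an nlohmann::json value as above):

    common_params_sampling sparams;
    sparams.grammar = "llg:json:" + schema.dump();  // the "llg:<kind>:<data>" form

    // With LLAMA_USE_LLGUIDANCE this takes the llama_sampler_init_llg branch;
    // without it, the same string hits the GGML_ABORT above.
    struct common_sampler * smpl = common_sampler_init(model, sparams);
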
@@ -102,3 +102,8 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
 
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
+
+#ifdef LLAMA_USE_LLGUIDANCE
+struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
+        const char * grammar_kind, const char * grammar_data);
+#endif
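
The declaration is only compiled in when the build enables the llguidance option, so callers outside this header need the same guard. A minimal caller-side sketch (fragment; grammar_kind, grammar_data, gbnf_text, and vocab come from context, and the fallback mirrors the else branch in common_sampler_init):

    struct llama_sampler * grmr;
    #ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(model, grammar_kind, grammar_data);
    #else
        // without llguidance, fall back to the built-in GBNF grammar sampler
        grmr = llama_sampler_init_grammar(vocab, gbnf_text, "root");
    #endif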