initial port of the previous LLG patch

Michal Moskal 2025-01-25 14:43:57 -08:00
parent 4a75d19376
commit 76290d9ea0
5 changed files with 277 additions and 1 deletion

.gitignore (+1)

@@ -143,3 +143,4 @@ poetry.toml
# Local scripts
/run-vim.sh
/run-chat.sh
include/llguidance.h

CMakeLists.txt

@@ -79,6 +79,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_LLGUIDANCE "llama: build LLGuidance library for structured output" OFF)
# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
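The option defaults to OFF; presumably it is enabled at configure time with something like cmake -B build -DLLAMA_LLGUIDANCE=ON, and the build is then expected to define the LLAMA_LLGUIDANCE preprocessor symbol that gates the source changes below.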

common/json-schema-to-grammar.cpp

@@ -991,11 +991,15 @@ public:
};
std::string json_schema_to_grammar(const json & schema) {
#ifdef LLAMA_LLGUIDANCE
return "llg:json:" + schema.dump();
#else
return build_grammar([&](const llama_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
#endif
}
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
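With LLAMA_LLGUIDANCE defined, the schema is no longer converted to a GBNF grammar here; it is serialized verbatim behind an llg:json: prefix and decoded later during sampler setup. A rough illustration (the example schema is made up, not from the patch):

// json_schema_to_grammar(json::parse(R"({"type":"integer"})"))
// returns the string  llg:json:{"type":"integer"}
// which common_sampler_init() in common/sampling.cpp (below) splits back into the
// grammar kind "json" and the raw schema text for llama_sampler_init_llg().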

common/llguidance.cpp (new file, +253)

@@ -0,0 +1,253 @@
#ifdef LLAMA_LLGUIDANCE
#include "llguidance.h"
struct llama_sampler_llg {
const struct llama_model * model;
std::string grammar_kind;
std::string grammar_data;
LlgTokenizer *tokenizer;
LlgConstraint *grammar;
LlgMaskResult llg_res;
bool has_llg_res;
};
static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
const char * grammar_kind, const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
if (llg_get_error(c)) {
LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(c));
llg_free_constraint(c);
return nullptr;
}
return c;
}
static const char * llama_sampler_llg_name(const struct llama_sampler * /*smpl*/) {
return "llguidance";
}
static void llama_sampler_llg_accept_impl(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
LlgCommitResult res;
llg_commit_token(ctx->grammar, token, &res);
ctx->has_llg_res = false;
}
}
static void llama_sampler_llg_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
if (!ctx->has_llg_res) {
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
ctx->has_llg_res = true;
} else {
LLAMA_LOG_ERROR("llg error: %s\n", llg_get_error(ctx->grammar));
llg_free_constraint(ctx->grammar);
ctx->grammar = nullptr;
}
}
if (ctx->has_llg_res) {
if (ctx->llg_res.is_stop) {
for (size_t i = 0; i < cur_p->size; ++i) {
if (!llama_token_is_eog(ctx->model, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
}
}
} else {
const uint32_t *mask = ctx->llg_res.sample_mask;
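// Mask layout, as implied by the indexing below (not documented in this patch):
// sample_mask holds one bit per token id, 32 ids per uint32_t, so e.g. token id 70
// is tested via mask[70 / 32] = mask[2], bit 70 % 32 = 6. Tokens whose bit is clear
// get their logit forced to -INFINITY and can no longer be sampled.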
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
cur_p->data[i].logit = -INFINITY;
}
}
}
}
}
}
static void llama_sampler_llg_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (!ctx->grammar) {
return;
}
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
llg_free_constraint(ctx->grammar);
ctx->grammar = grammar_new;
ctx->has_llg_res = false;
}
static struct llama_sampler * llama_sampler_llg_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
auto * result = llama_sampler_init_llg(ctx->model, nullptr, nullptr);
// copy the state
{
auto * result_ctx = (llama_sampler_llg *) result->ctx;
if (ctx->grammar) {
result_ctx->grammar_kind = ctx->grammar_kind;
result_ctx->grammar_data = ctx->grammar_data;
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
}
}
return result;
}
static void llama_sampler_llg_free(struct llama_sampler * smpl) {
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
llg_free_constraint(ctx->grammar);
llg_free_tokenizer(ctx->tokenizer);
}
delete ctx;
}
static struct llama_sampler_i llama_sampler_llg_i = {
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
};
static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
const uint8_t *bytes,
size_t bytes_len,
uint32_t *output_tokens,
size_t output_tokens_len)
{
const struct llama_model *model = (const struct llama_model *)user_data;
int r = llama_tokenize(model, (const char *) bytes, bytes_len,
(int32_t*)output_tokens, output_tokens_len, false, true);
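// llama_tokenize signals a too-small output buffer by returning the negative of the
// required token count; pass that count back, since the tokenize_fn contract here is
// understood to be "return the total number of tokens, even if it exceeds the buffer".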
if (r < 0)
return -r;
return r;
}
static LlgTokenizer *llama_sampler_llg_new_tokenizer(const struct llama_model * model) {
// TODO store the tokenizer in the model somehow
static const struct llama_model *model_cache;
static LlgTokenizer *tokenizer_cache;
if (model_cache == model) {
return llg_clone_tokenizer(tokenizer_cache);
}
auto tok_eos = llama_token_eot(model);
if (tok_eos == LLAMA_TOKEN_NULL)
tok_eos = llama_token_eos(model);
size_t vocab_size = llama_n_vocab(model);
auto token_lens = new uint32_t[vocab_size];
// we typically have ~7 bytes per token; let's go on the safe side here
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
auto token_bytes = new uint8_t[token_bytes_size];
size_t offset = 0;
for (size_t i = 0; i < vocab_size; i++) {
size_t max_token = 1024;
if (token_bytes_size - offset < max_token) {
GGML_ABORT("token_bytes buffer too small\n");
}
llama_token token = i;
auto dp = (char *) token_bytes + offset;
auto size = llama_detokenize(model, &token, 1, dp, max_token, false, false);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size == 0) {
size = llama_detokenize(model, &token, 1, dp + 1, max_token - 1, false, true);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size != 0) {
*dp = '\xff'; // special token prefix marker
size += 1;
}
}
token_lens[i] = size;
offset += size;
}
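// At this point token_bytes holds every token's text back to back; token i occupies
// token_lens[i] bytes, with a leading 0xff byte marking tokens that only detokenize
// as special tokens.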
LlgTokenizerInit tinit = {
/* .vocab_size = */ (uint32_t)vocab_size,
/* .tok_eos = */ (uint32_t)tok_eos,
/* .token_lens = */ token_lens,
/* .token_bytes = */ token_bytes,
/* .tokenizer_json = */ nullptr,
/* .tokenize_assumes_string = */ false,
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
/* .use_approximate_greedy_tokenize_fn = */ false,
/* .tokenize_user_data = */ model,
};
char error_buffer[1024];
LlgTokenizer *tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
delete[] token_bytes;
delete[] token_lens;
if (tokenizer == nullptr) {
LLAMA_LOG_ERROR("llg tokenizer error: %s\n", error_buffer);
return tokenizer;
}
if (tokenizer_cache) {
llg_free_tokenizer(tokenizer_cache);
}
model_cache = model;
tokenizer_cache = tokenizer;
return tokenizer;
}
struct llama_sampler * llama_sampler_init_llg(const struct llama_model * model,
const char * grammar_kind, const char * grammar_data) {
auto * ctx = new llama_sampler_llg;
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
auto tokenizer = llama_sampler_llg_new_tokenizer(model);
*ctx = {
/* .model = */ model,
/* .grammar_kind = */ grammar_kind,
/* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer,
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
} else {
*ctx = {
/* .model = */ model,
/* .grammar_kind = */ {},
/* .grammar_data = */ {},
/* .tokenizer = */ nullptr,
/* .grammar = */ nullptr,
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
}
return new llama_sampler {
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
}
#endif
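A minimal usage sketch for the new entry point (assumptions, not stated in the patch: "regex" is one of the grammar kinds accepted by llg_new_constraint_any, and the returned sampler is used like any other llama_sampler):

// constrain generation to digit strings via llguidance
struct llama_sampler * smpl = llama_sampler_init_llg(model, "regex", "[0-9]+");
// ... add smpl to the sampling chain / apply it during generation ...
llama_sampler_free(smpl);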

common/sampling.cpp

@@ -151,9 +151,26 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
struct llama_sampler * grmr;
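// params.grammar values of the form "llg:<kind>:<data>", e.g. "llg:json:{...}" as
// produced by json_schema_to_grammar() when LLAMA_LLGUIDANCE is set, select the
// llguidance-backed sampler; any other grammar string uses the existing GBNF sampler.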
if (params.grammar.compare(0, 4, "llg:") == 0) {
#ifdef LLAMA_LLGUIDANCE
auto gp = params.grammar.find(':', 4);
if (gp == std::string::npos) {
GGML_ABORT("invalid serialized grammar");
}
auto grm_type = params.grammar.substr(4, gp - 4);
auto grm_data = params.grammar.c_str() + gp + 1;
grmr = llama_sampler_init_llg(model, grm_type.c_str(), grm_data);
#else
GGML_ABORT("llguidance (LLAMA_LLGUIDANCE cmake parameter) is not enabled");
#endif
} else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},