diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 757562758..17fa52752 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -1,24 +1,24 @@ #ifdef LLAMA_USE_LLGUIDANCE -#include "common.h" -#include "sampling.h" -#include "log.h" -#include "llama.h" +# include "common.h" +# include "sampling.h" +# include "log.h" +# include "llama.h" -#include "llguidance.h" +# include "llguidance.h" struct llama_sampler_llg { const llama_vocab * vocab; - std::string grammar_kind; - std::string grammar_data; - LlgTokenizer *tokenizer; - LlgConstraint *grammar; - LlgMaskResult llg_res; - bool has_llg_res; + std::string grammar_kind; + std::string grammar_data; + LlgTokenizer * tokenizer; + LlgConstraint * grammar; + LlgMaskResult llg_res; + bool has_llg_res; }; -static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer, - const char * grammar_kind, const char * grammar_data) { +static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind, + const char * grammar_data) { LlgConstraintInit cinit; llg_constraint_init_set_defaults(&cinit, tokenizer); // cinit.log_stderr_level = 2; @@ -64,7 +64,7 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array } } } else { - const uint32_t *mask = ctx->llg_res.sample_mask; + const uint32_t * mask = ctx->llg_res.sample_mask; for (size_t i = 0; i < cur_p->size; ++i) { auto token = cur_p->data[i].id; if ((mask[token / 32] & (1 << (token % 32))) == 0) { @@ -84,7 +84,7 @@ static void llama_sampler_llg_reset(llama_sampler * smpl) { auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str()); llg_free_constraint(ctx->grammar); - ctx->grammar = grammar_new; + ctx->grammar = grammar_new; ctx->has_llg_res = false; } @@ -100,8 +100,8 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) { if (ctx->grammar) { result_ctx->grammar_kind = ctx->grammar_kind; result_ctx->grammar_data = ctx->grammar_data; - result_ctx->grammar = llg_clone_constraint(ctx->grammar); - result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); + result_ctx->grammar = llg_clone_constraint(ctx->grammar); + result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer); } } @@ -128,45 +128,42 @@ static llama_sampler_i llama_sampler_llg_i = { /* .free = */ llama_sampler_llg_free, }; - -static size_t llama_sampler_llg_tokenize_fn(const void *user_data, - const uint8_t *bytes, - size_t bytes_len, - uint32_t *output_tokens, - size_t output_tokens_len) -{ - const llama_vocab *vocab = (const llama_vocab *)user_data; - int r = 0; +static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, + uint32_t * output_tokens, size_t output_tokens_len) { + const llama_vocab * vocab = (const llama_vocab *) user_data; + int r = 0; try { - r = llama_tokenize(vocab, (const char *) bytes, bytes_len, - (int32_t*)output_tokens, output_tokens_len, false, true); - } catch (const std::exception &e) { + r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false, + true); + } catch (const std::exception & e) { GGML_ABORT("llama_tokenize failed: %s\n", e.what()); } - if (r < 0) + if (r < 0) { return -r; + } return r; } -static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) { +static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) { // TODO store the tokenizer in the vocab somehow - static const llama_vocab *vocab_cache; - static LlgTokenizer *tokenizer_cache; + static const llama_vocab * vocab_cache; + static LlgTokenizer * tokenizer_cache; if (vocab_cache == vocab) { return llg_clone_tokenizer(tokenizer_cache); } auto tok_eos = llama_vocab_eot(vocab); - if (tok_eos == LLAMA_TOKEN_NULL) + if (tok_eos == LLAMA_TOKEN_NULL) { tok_eos = llama_vocab_eos(vocab); + } size_t vocab_size = llama_vocab_n_tokens(vocab); - auto token_lens = new uint32_t[vocab_size]; + auto token_lens = new uint32_t[vocab_size]; // we typically have ~7 bytes per token; let's go on the safe side here auto token_bytes_size = vocab_size * 16 + 1024 * 1024; - auto token_bytes = new uint8_t[token_bytes_size]; + auto token_bytes = new uint8_t[token_bytes_size]; size_t offset = 0; for (size_t i = 0; i < vocab_size; i++) { @@ -176,8 +173,8 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) } llama_token token = i; - auto dp = (char *) token_bytes + offset; - auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false); + auto dp = (char *) token_bytes + offset; + auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false); if (size < 0) { GGML_ABORT("llama_detokenize failed\n"); } @@ -187,7 +184,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) GGML_ABORT("llama_detokenize failed\n"); } if (size != 0) { - *dp = '\xff'; // special token prefix marker + *dp = '\xff'; // special token prefix marker size += 1; } } @@ -196,10 +193,9 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) offset += size; } - LlgTokenizerInit tinit = { - /* .vocab_size = */ (uint32_t)vocab_size, - /* .tok_eos = */ (uint32_t)tok_eos, + /* .vocab_size = */ (uint32_t) vocab_size, + /* .tok_eos = */ (uint32_t) tok_eos, /* .token_lens = */ token_lens, /* .token_bytes = */ token_bytes, /* .tokenizer_json = */ nullptr, @@ -209,8 +205,8 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) /* .tokenize_user_data = */ vocab, }; - char error_buffer[1024]; - LlgTokenizer *tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer)); + char error_buffer[1024]; + LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer)); delete[] token_bytes; delete[] token_lens; @@ -223,19 +219,19 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) if (tokenizer_cache) { llg_free_tokenizer(tokenizer_cache); } - vocab_cache = vocab; + vocab_cache = vocab; tokenizer_cache = tokenizer; return llg_clone_tokenizer(tokenizer_cache); } -llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, - const char * grammar_kind, const char * grammar_data) { +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, + const char * grammar_data) { auto * ctx = new llama_sampler_llg; if (grammar_kind != nullptr && grammar_kind[0] != '\0') { auto tokenizer = llama_sampler_llg_new_tokenizer(vocab); - *ctx = { + *ctx = { /* .vocab = */ vocab, /* .grammar_kind = */ grammar_kind, /* .grammar_data = */ grammar_data, @@ -256,10 +252,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, }; } - return new llama_sampler { + return new llama_sampler{ /* .iface = */ &llama_sampler_llg_i, /* .ctx = */ ctx, }; } -#endif // LLAMA_USE_LLGUIDANCE +#endif // LLAMA_USE_LLGUIDANCE