format file

This commit is contained in:
Michal Moskal 2025-01-26 10:11:39 -08:00
parent 44e1973af0
commit c9e9853e6c

View file

@ -1,24 +1,24 @@
#ifdef LLAMA_USE_LLGUIDANCE
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"
# include "common.h"
# include "sampling.h"
# include "log.h"
# include "llama.h"
#include "llguidance.h"
# include "llguidance.h"
struct llama_sampler_llg {
const llama_vocab * vocab;
std::string grammar_kind;
std::string grammar_data;
LlgTokenizer *tokenizer;
LlgConstraint *grammar;
LlgTokenizer * tokenizer;
LlgConstraint * grammar;
LlgMaskResult llg_res;
bool has_llg_res;
};
static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
const char * grammar_kind, const char * grammar_data) {
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
// cinit.log_stderr_level = 2;
@ -64,7 +64,7 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array
}
}
} else {
const uint32_t *mask = ctx->llg_res.sample_mask;
const uint32_t * mask = ctx->llg_res.sample_mask;
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
@ -128,38 +128,35 @@ static llama_sampler_i llama_sampler_llg_i = {
/* .free = */ llama_sampler_llg_free,
};
static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
const uint8_t *bytes,
size_t bytes_len,
uint32_t *output_tokens,
size_t output_tokens_len)
{
const llama_vocab *vocab = (const llama_vocab *)user_data;
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
uint32_t * output_tokens, size_t output_tokens_len) {
const llama_vocab * vocab = (const llama_vocab *) user_data;
int r = 0;
try {
r = llama_tokenize(vocab, (const char *) bytes, bytes_len,
(int32_t*)output_tokens, output_tokens_len, false, true);
} catch (const std::exception &e) {
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
true);
} catch (const std::exception & e) {
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
}
if (r < 0)
if (r < 0) {
return -r;
}
return r;
}
static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
// TODO store the tokenizer in the vocab somehow
static const llama_vocab *vocab_cache;
static LlgTokenizer *tokenizer_cache;
static const llama_vocab * vocab_cache;
static LlgTokenizer * tokenizer_cache;
if (vocab_cache == vocab) {
return llg_clone_tokenizer(tokenizer_cache);
}
auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL)
if (tok_eos == LLAMA_TOKEN_NULL) {
tok_eos = llama_vocab_eos(vocab);
}
size_t vocab_size = llama_vocab_n_tokens(vocab);
@ -196,10 +193,9 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
offset += size;
}
LlgTokenizerInit tinit = {
/* .vocab_size = */ (uint32_t)vocab_size,
/* .tok_eos = */ (uint32_t)tok_eos,
/* .vocab_size = */ (uint32_t) vocab_size,
/* .tok_eos = */ (uint32_t) tok_eos,
/* .token_lens = */ token_lens,
/* .token_bytes = */ token_bytes,
/* .tokenizer_json = */ nullptr,
@ -210,7 +206,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
};
char error_buffer[1024];
LlgTokenizer *tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
delete[] token_bytes;
delete[] token_lens;
@ -229,8 +225,8 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
return llg_clone_tokenizer(tokenizer_cache);
}
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data) {
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
const char * grammar_data) {
auto * ctx = new llama_sampler_llg;
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
@ -256,7 +252,7 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
};
}
return new llama_sampler {
return new llama_sampler{
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};