format file
This commit is contained in:
parent
44e1973af0
commit
c9e9853e6c
1 changed files with 46 additions and 50 deletions
|
@ -1,24 +1,24 @@
|
|||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
# include "common.h"
|
||||
# include "sampling.h"
|
||||
# include "log.h"
|
||||
# include "llama.h"
|
||||
|
||||
#include "llguidance.h"
|
||||
# include "llguidance.h"
|
||||
|
||||
struct llama_sampler_llg {
|
||||
const llama_vocab * vocab;
|
||||
std::string grammar_kind;
|
||||
std::string grammar_data;
|
||||
LlgTokenizer *tokenizer;
|
||||
LlgConstraint *grammar;
|
||||
LlgMaskResult llg_res;
|
||||
bool has_llg_res;
|
||||
std::string grammar_kind;
|
||||
std::string grammar_data;
|
||||
LlgTokenizer * tokenizer;
|
||||
LlgConstraint * grammar;
|
||||
LlgMaskResult llg_res;
|
||||
bool has_llg_res;
|
||||
};
|
||||
|
||||
static LlgConstraint *llama_sampler_llg_new(LlgTokenizer *tokenizer,
|
||||
const char * grammar_kind, const char * grammar_data) {
|
||||
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
|
||||
const char * grammar_data) {
|
||||
LlgConstraintInit cinit;
|
||||
llg_constraint_init_set_defaults(&cinit, tokenizer);
|
||||
// cinit.log_stderr_level = 2;
|
||||
|
@ -64,7 +64,7 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array
|
|||
}
|
||||
}
|
||||
} else {
|
||||
const uint32_t *mask = ctx->llg_res.sample_mask;
|
||||
const uint32_t * mask = ctx->llg_res.sample_mask;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
auto token = cur_p->data[i].id;
|
||||
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
|
||||
|
@ -84,7 +84,7 @@ static void llama_sampler_llg_reset(llama_sampler * smpl) {
|
|||
|
||||
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
|
||||
llg_free_constraint(ctx->grammar);
|
||||
ctx->grammar = grammar_new;
|
||||
ctx->grammar = grammar_new;
|
||||
ctx->has_llg_res = false;
|
||||
}
|
||||
|
||||
|
@ -100,8 +100,8 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
|
|||
if (ctx->grammar) {
|
||||
result_ctx->grammar_kind = ctx->grammar_kind;
|
||||
result_ctx->grammar_data = ctx->grammar_data;
|
||||
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
|
||||
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
|
||||
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
|
||||
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,45 +128,42 @@ static llama_sampler_i llama_sampler_llg_i = {
|
|||
/* .free = */ llama_sampler_llg_free,
|
||||
};
|
||||
|
||||
|
||||
static size_t llama_sampler_llg_tokenize_fn(const void *user_data,
|
||||
const uint8_t *bytes,
|
||||
size_t bytes_len,
|
||||
uint32_t *output_tokens,
|
||||
size_t output_tokens_len)
|
||||
{
|
||||
const llama_vocab *vocab = (const llama_vocab *)user_data;
|
||||
int r = 0;
|
||||
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
|
||||
uint32_t * output_tokens, size_t output_tokens_len) {
|
||||
const llama_vocab * vocab = (const llama_vocab *) user_data;
|
||||
int r = 0;
|
||||
try {
|
||||
r = llama_tokenize(vocab, (const char *) bytes, bytes_len,
|
||||
(int32_t*)output_tokens, output_tokens_len, false, true);
|
||||
} catch (const std::exception &e) {
|
||||
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
|
||||
true);
|
||||
} catch (const std::exception & e) {
|
||||
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
|
||||
}
|
||||
if (r < 0)
|
||||
if (r < 0) {
|
||||
return -r;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
|
||||
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
|
||||
// TODO store the tokenizer in the vocab somehow
|
||||
static const llama_vocab *vocab_cache;
|
||||
static LlgTokenizer *tokenizer_cache;
|
||||
static const llama_vocab * vocab_cache;
|
||||
static LlgTokenizer * tokenizer_cache;
|
||||
|
||||
if (vocab_cache == vocab) {
|
||||
return llg_clone_tokenizer(tokenizer_cache);
|
||||
}
|
||||
|
||||
auto tok_eos = llama_vocab_eot(vocab);
|
||||
if (tok_eos == LLAMA_TOKEN_NULL)
|
||||
if (tok_eos == LLAMA_TOKEN_NULL) {
|
||||
tok_eos = llama_vocab_eos(vocab);
|
||||
}
|
||||
|
||||
size_t vocab_size = llama_vocab_n_tokens(vocab);
|
||||
|
||||
auto token_lens = new uint32_t[vocab_size];
|
||||
auto token_lens = new uint32_t[vocab_size];
|
||||
// we typically have ~7 bytes per token; let's go on the safe side here
|
||||
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
|
||||
auto token_bytes = new uint8_t[token_bytes_size];
|
||||
auto token_bytes = new uint8_t[token_bytes_size];
|
||||
|
||||
size_t offset = 0;
|
||||
for (size_t i = 0; i < vocab_size; i++) {
|
||||
|
@ -176,8 +173,8 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
|
|||
}
|
||||
|
||||
llama_token token = i;
|
||||
auto dp = (char *) token_bytes + offset;
|
||||
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
|
||||
auto dp = (char *) token_bytes + offset;
|
||||
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
|
||||
if (size < 0) {
|
||||
GGML_ABORT("llama_detokenize failed\n");
|
||||
}
|
||||
|
@ -187,7 +184,7 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
|
|||
GGML_ABORT("llama_detokenize failed\n");
|
||||
}
|
||||
if (size != 0) {
|
||||
*dp = '\xff'; // special token prefix marker
|
||||
*dp = '\xff'; // special token prefix marker
|
||||
size += 1;
|
||||
}
|
||||
}
|
||||
|
@ -196,10 +193,9 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
|
|||
offset += size;
|
||||
}
|
||||
|
||||
|
||||
LlgTokenizerInit tinit = {
|
||||
/* .vocab_size = */ (uint32_t)vocab_size,
|
||||
/* .tok_eos = */ (uint32_t)tok_eos,
|
||||
/* .vocab_size = */ (uint32_t) vocab_size,
|
||||
/* .tok_eos = */ (uint32_t) tok_eos,
|
||||
/* .token_lens = */ token_lens,
|
||||
/* .token_bytes = */ token_bytes,
|
||||
/* .tokenizer_json = */ nullptr,
|
||||
|
@ -209,8 +205,8 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
|
|||
/* .tokenize_user_data = */ vocab,
|
||||
};
|
||||
|
||||
char error_buffer[1024];
|
||||
LlgTokenizer *tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
|
||||
char error_buffer[1024];
|
||||
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
|
||||
|
||||
delete[] token_bytes;
|
||||
delete[] token_lens;
|
||||
|
@ -223,19 +219,19 @@ static LlgTokenizer *llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
|
|||
if (tokenizer_cache) {
|
||||
llg_free_tokenizer(tokenizer_cache);
|
||||
}
|
||||
vocab_cache = vocab;
|
||||
vocab_cache = vocab;
|
||||
tokenizer_cache = tokenizer;
|
||||
|
||||
return llg_clone_tokenizer(tokenizer_cache);
|
||||
}
|
||||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||
const char * grammar_kind, const char * grammar_data) {
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
|
||||
const char * grammar_data) {
|
||||
auto * ctx = new llama_sampler_llg;
|
||||
|
||||
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
|
||||
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
|
||||
*ctx = {
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_kind = */ grammar_kind,
|
||||
/* .grammar_data = */ grammar_data,
|
||||
|
@ -256,10 +252,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
|||
};
|
||||
}
|
||||
|
||||
return new llama_sampler {
|
||||
return new llama_sampler{
|
||||
/* .iface = */ &llama_sampler_llg_i,
|
||||
/* .ctx = */ ctx,
|
||||
};
|
||||
}
|
||||
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue