sampling : fix grammar apply

This commit is contained in:
Georgi Gerganov 2024-09-04 21:48:57 +03:00
parent 8e80a1cf6b
commit 9b950671f4
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 6 additions and 5 deletions

View file

@@ -132,7 +132,7 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool appl
 llama_sampler_accept(gsmpl->smpl, token);
 }
-void gpt_sampler_reset (struct gpt_sampler * gsmpl) {
+void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
 llama_constraint_reset(gsmpl->grmr);
 llama_sampler_reset(gsmpl->smpl);

View file

@@ -37,9 +37,6 @@ int main(int argc, char ** argv) {
 return 1;
 }
-// for probabilities to be computed even with temp = 0
-params.sparams.n_probs = 16;
 // max number of parallel drafting sequences (i.e. tree branches)
 const int n_seq_dft = params.n_parallel;

View file

@@ -855,6 +855,8 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
 // grammar
 struct llama_constraint_context_grammar {
+const struct llama_vocab * vocab;
 std::string grammar_str;
 std::string grammar_root;
@@ -889,7 +891,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
 /* .copy = */ [](const struct llama_constraint * cnstr) {
 const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx;
-auto * result = llama_constraint_init_grammar_impl(*ctx_src->grammar->vocab, nullptr, nullptr);
+auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
 auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx;
 if (ctx_src->grammar) {
@@ -917,12 +919,14 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
 if (grammar_str != nullptr && grammar_str[0] != '\0') {
 *ctx = {
+/*.vocab = */ &vocab,
 /*.grammar_str = */ grammar_str,
 /*.grammar_root = */ grammar_root,
 /*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
 };
 } else {
 *ctx = {
+/*.vocab = */ &vocab,
 /*.grammar_str = */ {},
 /*.grammar_root = */ {},
 /*.grammar = */ nullptr,