sampling : fix grammar apply
This commit is contained in:
parent
8e80a1cf6b
commit
9b950671f4
3 changed files with 6 additions and 5 deletions
|
@ -132,7 +132,7 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool appl
|
||||||
llama_sampler_accept(gsmpl->smpl, token);
|
llama_sampler_accept(gsmpl->smpl, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gpt_sampler_reset (struct gpt_sampler * gsmpl) {
|
void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
|
||||||
llama_constraint_reset(gsmpl->grmr);
|
llama_constraint_reset(gsmpl->grmr);
|
||||||
|
|
||||||
llama_sampler_reset(gsmpl->smpl);
|
llama_sampler_reset(gsmpl->smpl);
|
||||||
|
|
|
@ -37,9 +37,6 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// for probabilities to be computed even with temp = 0
|
|
||||||
params.sparams.n_probs = 16;
|
|
||||||
|
|
||||||
// max number of parallel drafting sequences (i.e. tree branches)
|
// max number of parallel drafting sequences (i.e. tree branches)
|
||||||
const int n_seq_dft = params.n_parallel;
|
const int n_seq_dft = params.n_parallel;
|
||||||
|
|
||||||
|
|
|
@ -855,6 +855,8 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
|
||||||
// grammar
|
// grammar
|
||||||
|
|
||||||
struct llama_constraint_context_grammar {
|
struct llama_constraint_context_grammar {
|
||||||
|
const struct llama_vocab * vocab;
|
||||||
|
|
||||||
std::string grammar_str;
|
std::string grammar_str;
|
||||||
std::string grammar_root;
|
std::string grammar_root;
|
||||||
|
|
||||||
|
@ -889,7 +891,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = {
|
||||||
/* .copy = */ [](const struct llama_constraint * cnstr) {
|
/* .copy = */ [](const struct llama_constraint * cnstr) {
|
||||||
const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx;
|
const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx;
|
||||||
|
|
||||||
auto * result = llama_constraint_init_grammar_impl(*ctx_src->grammar->vocab, nullptr, nullptr);
|
auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
|
||||||
|
|
||||||
auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx;
|
auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx;
|
||||||
if (ctx_src->grammar) {
|
if (ctx_src->grammar) {
|
||||||
|
@ -917,12 +919,14 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
|
||||||
|
|
||||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||||
*ctx = {
|
*ctx = {
|
||||||
|
/*.vocab = */ &vocab,
|
||||||
/*.grammar_str = */ grammar_str,
|
/*.grammar_str = */ grammar_str,
|
||||||
/*.grammar_root = */ grammar_root,
|
/*.grammar_root = */ grammar_root,
|
||||||
/*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
/*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
*ctx = {
|
*ctx = {
|
||||||
|
/*.vocab = */ &vocab,
|
||||||
/*.grammar_str = */ {},
|
/*.grammar_str = */ {},
|
||||||
/*.grammar_root = */ {},
|
/*.grammar_root = */ {},
|
||||||
/*.grammar = */ nullptr,
|
/*.grammar = */ nullptr,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue