sampling : fix state cloning

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-07 14:38:00 +03:00
parent 0e6d170a50
commit 8a82f388cd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -643,7 +643,16 @@ static struct llama_sampler_i llama_sampler_dist_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_dist *) smpl->ctx; const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
return llama_sampler_init_dist(ctx->seed); auto * result = llama_sampler_init_dist(ctx->seed);
// copy the state
{
auto * result_ctx = (llama_sampler_dist *) result->ctx;
result_ctx->rng = ctx->rng;
}
return result;
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx; delete (llama_sampler_dist *) smpl->ctx;
@ -987,7 +996,17 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
}, },
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
// copy the state
{
auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
result_ctx->mu = ctx->mu;
result_ctx->rng = ctx->rng;
}
return result;
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat *) smpl->ctx; delete (llama_sampler_mirostat *) smpl->ctx;
@ -1062,7 +1081,18 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
}, },
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
// copy the state
{
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
result_ctx->mu = ctx->mu;
result_ctx->rng = ctx->rng;
}
return result;
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat_v2 *) smpl->ctx; delete (llama_sampler_mirostat_v2 *) smpl->ctx;
@ -1120,16 +1150,20 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
ctx->grammar = grammar_new; ctx->grammar = grammar_new;
}, },
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx; const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
auto * ctx_dst = (llama_sampler_grammar *) result->ctx; // copy the state
if (ctx_src->grammar) { {
ctx_dst->grammar_str = ctx_src->grammar_str; auto * result_ctx = (llama_sampler_grammar *) result->ctx;
ctx_dst->grammar_root = ctx_src->grammar_root;
ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar); if (ctx->grammar) {
result_ctx->grammar_str = ctx->grammar_str;
result_ctx->grammar_root = ctx->grammar_root;
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
}
} }
return result; return result;
@ -1262,20 +1296,24 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
ctx->prev.clear(); ctx->prev.clear();
}, },
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx; const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties( auto * result = llama_sampler_init_penalties(
ctx_src->n_vocab, ctx->n_vocab,
ctx_src->special_eos_id, ctx->special_eos_id,
ctx_src->linefeed_id, ctx->linefeed_id,
ctx_src->penalty_last_n, ctx->penalty_last_n,
ctx_src->penalty_repeat, ctx->penalty_repeat,
ctx_src->penalty_freq, ctx->penalty_freq,
ctx_src->penalty_present, ctx->penalty_present,
ctx_src->penalize_nl, ctx->penalize_nl,
ctx_src->ignore_eos); ctx->ignore_eos);
auto * ctx_dst = (llama_sampler_penalties *) result->ctx; // copy the state
ctx_dst->prev = ctx_src->prev; {
auto * result_ctx = (llama_sampler_penalties *) result->ctx;
result_ctx->prev = ctx->prev;
}
return result; return result;
}, },
@ -1358,8 +1396,8 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
}, },
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx; const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_logit_bias *) smpl->ctx; delete (llama_sampler_logit_bias *) smpl->ctx;