sampling : fix state cloning

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-07 14:38:00 +03:00
parent 0e6d170a50
commit 8a82f388cd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -643,7 +643,16 @@ static struct llama_sampler_i llama_sampler_dist_i = {
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
return llama_sampler_init_dist(ctx->seed);
auto * result = llama_sampler_init_dist(ctx->seed);
// copy the state
{
auto * result_ctx = (llama_sampler_dist *) result->ctx;
result_ctx->rng = ctx->rng;
}
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx;
@ -987,7 +996,17 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
// copy the state
{
auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
result_ctx->mu = ctx->mu;
result_ctx->rng = ctx->rng;
}
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat *) smpl->ctx;
@ -1062,7 +1081,18 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
// copy the state
{
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
result_ctx->mu = ctx->mu;
result_ctx->rng = ctx->rng;
}
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
@ -1120,16 +1150,20 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
ctx->grammar = grammar_new;
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx;
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
auto * ctx_dst = (llama_sampler_grammar *) result->ctx;
if (ctx_src->grammar) {
ctx_dst->grammar_str = ctx_src->grammar_str;
ctx_dst->grammar_root = ctx_src->grammar_root;
// copy the state
{
auto * result_ctx = (llama_sampler_grammar *) result->ctx;
ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar);
if (ctx->grammar) {
result_ctx->grammar_str = ctx->grammar_str;
result_ctx->grammar_root = ctx->grammar_root;
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
}
}
return result;
@ -1262,20 +1296,24 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
ctx->prev.clear();
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx;
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties(
ctx_src->n_vocab,
ctx_src->special_eos_id,
ctx_src->linefeed_id,
ctx_src->penalty_last_n,
ctx_src->penalty_repeat,
ctx_src->penalty_freq,
ctx_src->penalty_present,
ctx_src->penalize_nl,
ctx_src->ignore_eos);
ctx->n_vocab,
ctx->special_eos_id,
ctx->linefeed_id,
ctx->penalty_last_n,
ctx->penalty_repeat,
ctx->penalty_freq,
ctx->penalty_present,
ctx->penalize_nl,
ctx->ignore_eos);
auto * ctx_dst = (llama_sampler_penalties *) result->ctx;
ctx_dst->prev = ctx_src->prev;
// copy the state
{
auto * result_ctx = (llama_sampler_penalties *) result->ctx;
result_ctx->prev = ctx->prev;
}
return result;
},
@ -1358,8 +1396,8 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_logit_bias *) smpl->ctx;