sampling : fix state cloning
ggml-ci
This commit is contained in:
parent
0e6d170a50
commit
8a82f388cd
1 changed file with 62 additions and 24 deletions
|
@@ -643,7 +643,16 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
||||
return llama_sampler_init_dist(ctx->seed);
|
||||
auto * result = llama_sampler_init_dist(ctx->seed);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_dist *) result->ctx;
|
||||
|
||||
result_ctx->rng = ctx->rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_dist *) smpl->ctx;
|
||||
|
@@ -987,7 +996,17 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
|
|||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
|
||||
return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
|
||||
auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_mirostat *) result->ctx;
|
||||
|
||||
result_ctx->mu = ctx->mu;
|
||||
result_ctx->rng = ctx->rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_mirostat *) smpl->ctx;
|
||||
|
@@ -1062,7 +1081,18 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
|
||||
return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
|
||||
|
||||
auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
|
||||
|
||||
result_ctx->mu = ctx->mu;
|
||||
result_ctx->rng = ctx->rng;
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
|
||||
|
@@ -1120,16 +1150,20 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
|
|||
ctx->grammar = grammar_new;
|
||||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx;
|
||||
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
|
||||
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
|
||||
|
||||
auto * ctx_dst = (llama_sampler_grammar *) result->ctx;
|
||||
if (ctx_src->grammar) {
|
||||
ctx_dst->grammar_str = ctx_src->grammar_str;
|
||||
ctx_dst->grammar_root = ctx_src->grammar_root;
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_grammar *) result->ctx;
|
||||
|
||||
ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar);
|
||||
if (ctx->grammar) {
|
||||
result_ctx->grammar_str = ctx->grammar_str;
|
||||
result_ctx->grammar_root = ctx->grammar_root;
|
||||
|
||||
result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@@ -1262,20 +1296,24 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
|
|||
ctx->prev.clear();
|
||||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx;
|
||||
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
|
||||
auto * result = llama_sampler_init_penalties(
|
||||
ctx_src->n_vocab,
|
||||
ctx_src->special_eos_id,
|
||||
ctx_src->linefeed_id,
|
||||
ctx_src->penalty_last_n,
|
||||
ctx_src->penalty_repeat,
|
||||
ctx_src->penalty_freq,
|
||||
ctx_src->penalty_present,
|
||||
ctx_src->penalize_nl,
|
||||
ctx_src->ignore_eos);
|
||||
ctx->n_vocab,
|
||||
ctx->special_eos_id,
|
||||
ctx->linefeed_id,
|
||||
ctx->penalty_last_n,
|
||||
ctx->penalty_repeat,
|
||||
ctx->penalty_freq,
|
||||
ctx->penalty_present,
|
||||
ctx->penalize_nl,
|
||||
ctx->ignore_eos);
|
||||
|
||||
auto * ctx_dst = (llama_sampler_penalties *) result->ctx;
|
||||
ctx_dst->prev = ctx_src->prev;
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_penalties *) result->ctx;
|
||||
|
||||
result_ctx->prev = ctx->prev;
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
|
@@ -1358,8 +1396,8 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
|||
},
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx;
|
||||
return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
|
||||
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
|
||||
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_logit_bias *) smpl->ctx;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue