sampling : remove _context suffix [no ci]

This commit is contained in:
Georgi Gerganov 2024-09-06 12:52:19 +03:00
parent b448c753b9
commit 5ab52c1f64
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -624,7 +624,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
// dist
struct llama_sampler_context_dist {
struct llama_sampler_dist {
const uint32_t seed;
std::mt19937 rng;
@ -636,23 +636,23 @@ static struct llama_sampler_i llama_sampler_dist_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_context_dist *) smpl->ctx;
auto * ctx = (llama_sampler_dist *) smpl->ctx;
cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx;
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
return llama_sampler_init_dist(ctx->seed);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_dist *) smpl->ctx;
delete (llama_sampler_dist *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
return new llama_sampler {
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_context_dist {
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .rng = */ std::mt19937(seed),
/* .probs = */ {},
@ -682,7 +682,7 @@ struct llama_sampler * llama_sampler_init_softmax() {
// top-k
struct llama_sampler_context_top_k {
struct llama_sampler_top_k {
const int32_t k;
};
@ -690,23 +690,23 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_top_k *) smpl->ctx;
const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
llama_sampler_top_k_impl(cur_p, ctx->k);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx;
const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
return llama_sampler_init_top_k(ctx->k);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_top_k *) smpl->ctx;
delete (llama_sampler_top_k *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return new llama_sampler {
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_context_top_k {
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
},
};
@ -714,7 +714,7 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
// top-p
struct llama_sampler_context_top_p {
struct llama_sampler_top_p {
const float p;
const size_t min_keep;
};
@ -723,23 +723,23 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_top_p *) smpl->ctx;
const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx;
const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_top_p *) smpl->ctx;
delete (llama_sampler_top_p *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return new llama_sampler {
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_context_top_p {
/* .ctx = */ new llama_sampler_top_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
@ -748,7 +748,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
// min-p
struct llama_sampler_context_min_p {
struct llama_sampler_min_p {
const float p;
const size_t min_keep;
};
@ -757,23 +757,23 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_min_p *) smpl->ctx;
const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx;
const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_min_p *) smpl->ctx;
delete (llama_sampler_min_p *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return new llama_sampler {
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_context_min_p {
/* .ctx = */ new llama_sampler_min_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
@ -782,7 +782,7 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
// tail-free
struct llama_sampler_context_tail_free {
struct llama_sampler_tail_free {
const float z;
const size_t min_keep;
};
@ -791,23 +791,23 @@ static struct llama_sampler_i llama_sampler_tail_free_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_tail_free *) smpl->ctx;
const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx;
const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_tail_free *) smpl->ctx;
delete (llama_sampler_tail_free *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
return new llama_sampler {
/* .iface = */ &llama_sampler_tail_free_i,
/* .ctx = */ new llama_sampler_context_tail_free {
/* .ctx = */ new llama_sampler_tail_free {
/* .z = */ z,
/*. min_keep = */ min_keep,
},
@ -816,7 +816,7 @@ struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
// typical
struct llama_sampler_context_typical {
struct llama_sampler_typical {
const float p;
const size_t min_keep;
};
@ -825,23 +825,23 @@ static struct llama_sampler_i llama_sampler_typical_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_typical *) smpl->ctx;
const auto * ctx = (llama_sampler_typical *) smpl->ctx;
llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx;
const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
return llama_sampler_init_typical(ctx->p, ctx->min_keep);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_typical *) smpl->ctx;
delete (llama_sampler_typical *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return new llama_sampler {
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_context_typical {
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
@ -850,7 +850,7 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
// temp
struct llama_sampler_context_temp {
struct llama_sampler_temp {
const float temp;
};
@ -858,23 +858,23 @@ static struct llama_sampler_i llama_sampler_temp_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_temp *) smpl->ctx;
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
llama_sampler_temp_impl(cur_p, ctx->temp);
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx;
const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
return llama_sampler_init_temp(ctx->temp);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_temp *) smpl->ctx;
delete (llama_sampler_temp *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_temp(float temp) {
return new llama_sampler {
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_context_temp {
/* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp,
},
};
@ -882,7 +882,7 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
// temp-ext
struct llama_sampler_context_temp_ext {
struct llama_sampler_temp_ext {
const float temp;
const float delta;
const float exponent;
@ -892,7 +892,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_temp_ext *) smpl->ctx;
const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
if (ctx->delta > 0) {
const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
const float temp_max = ctx->temp + ctx->delta;
@ -904,18 +904,18 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx;
const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_temp_ext *) smpl->ctx;
delete (llama_sampler_temp_ext *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return new llama_sampler {
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_context_temp_ext {
/* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp,
/* .delta = */ delta,
/* .exponent = */ exponent,
@ -925,7 +925,7 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
// mirostat
struct llama_sampler_context_mirostat {
struct llama_sampler_mirostat {
const struct llama_vocab * vocab;
const uint32_t seed;
@ -946,7 +946,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
llama_sampler_softmax_impl(cur_p);
@ -980,23 +980,23 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
ctx->mu = ctx->mu - ctx->eta * e;
},
/* .reset = */ [](struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
ctx->mu = 2.0f*ctx->tau;
ctx->rng = std::mt19937(ctx->seed);
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx;
const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_mirostat *) smpl->ctx;
delete (llama_sampler_mirostat *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_context_mirostat {
/* .ctx = */ new llama_sampler_mirostat {
/* .vocab = */ &vocab,
/* .seed = */ seed,
/* .tau = */ tau,
@ -1011,7 +1011,7 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab
// mirostat v2
struct llama_sampler_context_mirostat_v2 {
struct llama_sampler_mirostat_v2 {
const uint32_t seed;
const float tau;
@ -1028,7 +1028,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
llama_sampler_softmax_impl(cur_p);
@ -1055,23 +1055,23 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
ctx->mu = ctx->mu - ctx->eta * e;
},
/* .reset = */ [](struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
ctx->mu = 2.0f*ctx->tau;
ctx->rng = std::mt19937(ctx->seed);
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
delete (llama_sampler_mirostat_v2 *) smpl->ctx;
},
};
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_context_mirostat_v2 {
/* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed,
/* .tau = */ tau,
/* .eta = */ eta,
@ -1084,7 +1084,7 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
// grammar
struct llama_sampler_context_grammar {
struct llama_sampler_grammar {
const struct llama_vocab * vocab;
std::string grammar_str;
@ -1096,19 +1096,19 @@ struct llama_sampler_context_grammar {
static struct llama_sampler_i llama_sampler_grammar_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; },
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx;
const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
if (ctx->grammar) {
llama_grammar_accept_impl(*ctx->grammar, token);
}
},
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx;
const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
if (ctx->grammar) {
llama_sampler_grammar_impl(cur_p, *ctx->grammar);
}
},
/* .reset = */ [](struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_context_grammar *) smpl->ctx;
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
if (!ctx->grammar) {
return;
}
@ -1119,11 +1119,11 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
ctx->grammar = grammar_new;
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_context_grammar *) smpl->ctx;
const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr);
auto * ctx_dst = (llama_sampler_context_grammar *) result->ctx;
auto * ctx_dst = (llama_sampler_grammar *) result->ctx;
if (ctx_src->grammar) {
ctx_dst->grammar_str = ctx_src->grammar_str;
ctx_dst->grammar_root = ctx_src->grammar_root;
@ -1134,7 +1134,7 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx;
const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
if (ctx->grammar) {
llama_grammar_free_impl(ctx->grammar);
@ -1145,7 +1145,7 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
};
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
auto * ctx = new llama_sampler_context_grammar;
auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') {
*ctx = {
@ -1171,7 +1171,7 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
// penalties
struct llama_sampler_context_penalties {
struct llama_sampler_penalties {
const struct llama_vocab * vocab;
const int32_t penalty_last_n;
@ -1188,11 +1188,11 @@ struct llama_sampler_context_penalties {
static struct llama_sampler_i llama_sampler_penalties_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; },
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_context_penalties *) smpl->ctx;
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
ctx->prev.push_back(token);
},
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_context_penalties *) smpl->ctx;
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary");
@ -1222,11 +1222,11 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
}
},
/* .reset = */ [](struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_context_penalties *) smpl->ctx;
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
ctx->prev.clear();
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_context_penalties *) smpl->ctx;
const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties_impl(
*ctx_src->vocab,
ctx_src->penalty_last_n,
@ -1236,13 +1236,13 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
ctx_src->penalize_nl,
ctx_src->ignore_eos);
auto * ctx_dst = (llama_sampler_context_penalties *) result->ctx;
auto * ctx_dst = (llama_sampler_penalties *) result->ctx;
ctx_dst->prev = ctx_src->prev;
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_penalties *) smpl->ctx;
delete (llama_sampler_penalties *) smpl->ctx;
},
};
@ -1252,7 +1252,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca
return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_context_penalties {
/* .ctx = */ new llama_sampler_penalties {
/* .vocab = */ &vocab,
/* .penalty_last_n = */ penalty_last_n,
/* .penalty_repeat = */ penalty_repeat,
@ -1267,7 +1267,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca
// logit-bias
struct llama_sampler_context_logit_bias {
struct llama_sampler_logit_bias {
const struct llama_vocab * vocab;
std::vector<llama_logit_bias> logit_bias;
@ -1277,7 +1277,7 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; },
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_context_logit_bias *) smpl->ctx;
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary");
@ -1287,11 +1287,11 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
},
/* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx_src = (const llama_sampler_context_logit_bias *) smpl->ctx;
const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
},
/* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_logit_bias *) smpl->ctx;
delete (llama_sampler_logit_bias *) smpl->ctx;
},
};
@ -1301,7 +1301,7 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
const llama_logit_bias * logit_bias) {
return new llama_sampler {
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_context_logit_bias {
/* .ctx = */ new llama_sampler_logit_bias {
/* .vocab = */ &vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
},