sampling : remove redundant indirection calls

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-06 12:39:43 +03:00
parent 809bdcf767
commit b448c753b9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 193 additions and 304 deletions

View file

@ -992,9 +992,9 @@ extern "C" {
// //
// Sampling API // Sampling API
// //
// In the future, it will be utilized offload the sampling to the backends (e.g. GPU). // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// //
// TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model // TODO: in the future, the entire API that uses llama_model should start using llama_vocab
typedef void * llama_sampler_context_t; typedef void * llama_sampler_context_t;
@ -1045,16 +1045,16 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, int32_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, int32_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, int32_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, int32_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.

View file

@ -432,6 +432,167 @@ void llama_sampler_penalties_impl(
cur_p->sorted = false; cur_p->sorted = false;
} }
// llama_sampler API
const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) {
return "(null)";
}
return smpl->iface->name(smpl);
}
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
if (smpl->iface->accept) {
smpl->iface->accept(smpl, token);
}
}
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
GGML_ASSERT(smpl->iface->apply);
smpl->iface->apply(smpl, cur_p);
}
void llama_sampler_reset(struct llama_sampler * smpl) {
if (smpl->iface->reset) {
smpl->iface->reset(smpl);
}
}
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
if (smpl->iface->clone) {
return smpl->iface->clone(smpl);
}
if (smpl->ctx == nullptr) {
return new llama_sampler {
/* .iface = */ smpl->iface,
/* .ctx = */ nullptr,
};
}
GGML_ABORT("the sampler does not support cloning");
}
void llama_sampler_free(struct llama_sampler * smpl) {
if (smpl == nullptr) {
return;
}
if (smpl->iface->free) {
smpl->iface->free(smpl);
}
delete smpl;
}
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
// TODO: do not allocate each time
std::vector<llama_token_data> cur(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_sampler_apply(smpl, &cur_p);
return cur_p.data[cur_p.selected].id;
}
// sampler chain
static struct llama_sampler_i llama_sampler_chain_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
time_meas tm(chain->t_sample_us, chain->params.no_timing);
for (auto * smpl : chain->samplers) {
llama_sampler_accept(smpl, token);
}
chain->n_sample++;
},
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
time_meas tm(chain->t_sample_us, chain->params.no_timing);
for (auto * smpl : chain->samplers) {
llama_sampler_apply(smpl, cur_p);
}
},
/* .reset = */ [](struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
llama_sampler_reset(smpl);
}
chain->t_sample_us = 0;
chain->n_sample = 0;
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
auto * result = llama_sampler_chain_init(chain_src->params);
for (auto * smpl : chain_src->samplers) {
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
}
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
llama_sampler_free(smpl);
}
delete chain;
},
};
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return new llama_sampler {
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
},
};
}
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
auto * p = (llama_sampler_chain *) chain->ctx;
p->samplers.push_back(smpl);
}
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
const auto * p = (const llama_sampler_chain *) chain->ctx;
if (i < 0 || i >= (int32_t) p->samplers.size()) {
return nullptr;
}
return p->samplers[i];
}
int llama_sampler_chain_n(const struct llama_sampler * chain) {
const auto * p = (const llama_sampler_chain *) chain->ctx;
return p->samplers.size();
}
// //
// samplers // samplers
// //
@ -454,7 +615,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
/* .free = */ nullptr, /* .free = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_greedy_impl() { struct llama_sampler * llama_sampler_init_greedy() {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_greedy_i, /* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr, /* .ctx = */ nullptr,
@ -481,14 +642,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx; const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx;
return llama_sampler_init_dist_impl(ctx->seed); return llama_sampler_init_dist(ctx->seed);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_dist *) smpl->ctx; delete (llama_sampler_context_dist *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) { struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_dist_i, /* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_context_dist { /* .ctx = */ new llama_sampler_context_dist {
@ -512,7 +673,7 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
/* .free = */ nullptr, /* .free = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_softmax_impl() { struct llama_sampler * llama_sampler_init_softmax() {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_softmax_i, /* .iface = */ &llama_sampler_softmax_i,
/* .ctx = */ nullptr, /* .ctx = */ nullptr,
@ -535,14 +696,14 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx; const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx;
return llama_sampler_init_top_k_impl(ctx->k); return llama_sampler_init_top_k(ctx->k);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_top_k *) smpl->ctx; delete (llama_sampler_context_top_k *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_top_k_impl(int32_t k) { struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_top_k_i, /* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_context_top_k { /* .ctx = */ new llama_sampler_context_top_k {
@ -568,14 +729,14 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx; const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx;
return llama_sampler_init_top_p_impl(ctx->p, ctx->min_keep); return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_top_p *) smpl->ctx; delete (llama_sampler_context_top_p *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_top_p_impl(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_top_p_i, /* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_context_top_p { /* .ctx = */ new llama_sampler_context_top_p {
@ -602,14 +763,14 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx; const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx;
return llama_sampler_init_min_p_impl(ctx->p, ctx->min_keep); return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_min_p *) smpl->ctx; delete (llama_sampler_context_min_p *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_min_p_impl(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_min_p_i, /* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_context_min_p { /* .ctx = */ new llama_sampler_context_min_p {
@ -636,14 +797,14 @@ static struct llama_sampler_i llama_sampler_tail_free_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx; const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx;
return llama_sampler_init_tail_free_impl(ctx->z, ctx->min_keep); return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_tail_free *) smpl->ctx; delete (llama_sampler_context_tail_free *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep) { struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_tail_free_i, /* .iface = */ &llama_sampler_tail_free_i,
/* .ctx = */ new llama_sampler_context_tail_free { /* .ctx = */ new llama_sampler_context_tail_free {
@ -670,14 +831,14 @@ static struct llama_sampler_i llama_sampler_typical_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx; const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx;
return llama_sampler_init_typical_impl(ctx->p, ctx->min_keep); return llama_sampler_init_typical(ctx->p, ctx->min_keep);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_typical *) smpl->ctx; delete (llama_sampler_context_typical *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_typical_impl(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_typical_i, /* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_context_typical { /* .ctx = */ new llama_sampler_context_typical {
@ -703,14 +864,14 @@ static struct llama_sampler_i llama_sampler_temp_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx; const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx;
return llama_sampler_init_temp_impl(ctx->temp); return llama_sampler_init_temp(ctx->temp);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_temp *) smpl->ctx; delete (llama_sampler_context_temp *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_temp_impl(float temp) { struct llama_sampler * llama_sampler_init_temp(float temp) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_temp_i, /* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_context_temp { /* .ctx = */ new llama_sampler_context_temp {
@ -744,14 +905,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx; const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx;
return llama_sampler_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_temp_ext *) smpl->ctx; delete (llama_sampler_context_temp_ext *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, float exponent) { struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_temp_ext_i, /* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_context_temp_ext { /* .ctx = */ new llama_sampler_context_temp_ext {
@ -900,14 +1061,14 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
}, },
/* .clone = */ [](const struct llama_sampler * smpl) { /* .clone = */ [](const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta); return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
}, },
/* .free = */ [](struct llama_sampler * smpl) { /* .free = */ [](struct llama_sampler * smpl) {
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
}, },
}; };
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) { struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_v2_i, /* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_context_mirostat_v2 { /* .ctx = */ new llama_sampler_context_mirostat_v2 {
@ -1146,143 +1307,3 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
}, },
}; };
} }
// sampler chain
static struct llama_sampler_i llama_sampler_chain_i = {
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
time_meas tm(chain->t_sample_us, chain->params.no_timing);
for (auto * smpl : chain->samplers) {
llama_sampler_accept_impl(*smpl, token);
}
chain->n_sample++;
},
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
time_meas tm(chain->t_sample_us, chain->params.no_timing);
for (auto * smpl : chain->samplers) {
llama_sampler_apply_impl(*smpl, cur_p);
}
},
/* .reset = */ [](struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
llama_sampler_reset_impl(*smpl);
}
chain->t_sample_us = 0;
chain->n_sample = 0;
},
/* .clone = */ [](const struct llama_sampler * smpl) {
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
auto * result = llama_sampler_chain_init_impl(chain_src->params);
auto * chain_dst = (llama_sampler_chain *) result->ctx;
for (auto * smpl : chain_src->samplers) {
llama_sampler_chain_add_impl(*chain_dst, llama_sampler_clone_impl(*smpl));
}
return result;
},
/* .free = */ [](struct llama_sampler * smpl) {
auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) {
llama_sampler_free_impl(smpl);
}
delete chain;
},
};
struct llama_sampler * llama_sampler_chain_init_impl(struct llama_sampler_chain_params params) {
return new llama_sampler {
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
},
};
}
void llama_sampler_chain_add_impl(struct llama_sampler_chain & chain, struct llama_sampler * smpl) {
chain.samplers.push_back(smpl);
}
struct llama_sampler * llama_sampler_chain_get_impl(const struct llama_sampler_chain & chain, int32_t i) {
if (i < 0 || i >= (int32_t) chain.samplers.size()) {
return nullptr;
}
return chain.samplers[i];
}
int llama_sampler_chain_n_impl(const struct llama_sampler_chain & chain) {
return chain.samplers.size();
}
////////////////////////////////////////
const char * llama_sampler_name_impl(const struct llama_sampler & smpl) {
if (!smpl.iface) {
return "(null)";
}
return smpl.iface->name(&smpl);
}
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
if (smpl.iface->accept) {
smpl.iface->accept(&smpl, token);
}
}
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) {
GGML_ASSERT(smpl.iface->apply);
smpl.iface->apply(&smpl, cur_p);
}
void llama_sampler_reset_impl(struct llama_sampler & smpl) {
if (smpl.iface->reset) {
smpl.iface->reset(&smpl);
}
}
struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) {
if (smpl.iface->clone) {
return smpl.iface->clone(&smpl);
}
if (smpl.ctx == nullptr) {
return new llama_sampler {
/* .iface = */ smpl.iface,
/* .ctx = */ nullptr,
};
}
GGML_ABORT("the sampler does not support cloning");
}
void llama_sampler_free_impl(struct llama_sampler * smpl) {
if (smpl == nullptr) {
return;
}
if (smpl->iface->free) {
smpl->iface->free(smpl);
}
delete smpl;
}

View file

@ -7,15 +7,6 @@
struct llama_vocab; struct llama_vocab;
struct llama_grammar; struct llama_grammar;
// samplers
const char * llama_sampler_name_impl (const struct llama_sampler & smpl);
void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token);
void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
void llama_sampler_reset_impl ( struct llama_sampler & smpl);
struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl);
void llama_sampler_free_impl ( struct llama_sampler * smpl);
// sampler chain // sampler chain
struct llama_sampler_chain { struct llama_sampler_chain {
@ -30,11 +21,6 @@ struct llama_sampler_chain {
mutable int32_t n_sample; mutable int32_t n_sample;
}; };
struct llama_sampler * llama_sampler_chain_init_impl( struct llama_sampler_chain_params params);
void llama_sampler_chain_add_impl ( struct llama_sampler_chain & chain, struct llama_sampler * smpl);
struct llama_sampler * llama_sampler_chain_get_impl (const struct llama_sampler_chain & chain, int32_t i);
int llama_sampler_chain_n_impl (const struct llama_sampler_chain & chain);
using llama_token_cnt = std::unordered_map<llama_token, int>; using llama_token_cnt = std::unordered_map<llama_token, int>;
// TODO: tmp exposed until test-sampling is fixed // TODO: tmp exposed until test-sampling is fixed
@ -45,17 +31,6 @@ void llama_sampler_penalties_impl(
float penalty_freq, float penalty_freq,
float penalty_present); float penalty_present);
struct llama_sampler * llama_sampler_init_greedy_impl ();
struct llama_sampler * llama_sampler_init_dist_impl (uint32_t seed);
struct llama_sampler * llama_sampler_init_softmax_impl ();
struct llama_sampler * llama_sampler_init_top_k_impl (int32_t k);
struct llama_sampler * llama_sampler_init_top_p_impl (float p, size_t min_keep);
struct llama_sampler * llama_sampler_init_min_p_impl (float p, size_t min_keep);
struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep);
struct llama_sampler * llama_sampler_init_typical_impl (float p, size_t min_keep);
struct llama_sampler * llama_sampler_init_temp_impl (float t);
struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta, float exponent);
struct llama_sampler * llama_sampler_init_mirostat_impl( struct llama_sampler * llama_sampler_init_mirostat_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
uint32_t seed, uint32_t seed,
@ -63,11 +38,6 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(
float eta, float eta,
int32_t m); int32_t m);
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(
uint32_t seed,
float tau,
float eta);
struct llama_sampler * llama_sampler_init_grammar_impl( struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
const char * grammar_str, const char * grammar_str,
@ -82,7 +52,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl(
bool penalize_nl, bool penalize_nl,
bool ignore_eos); bool ignore_eos);
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias); const llama_logit_bias * logit_bias);

View file

@ -20592,102 +20592,17 @@ int32_t llama_chat_apply_template(
// sampling // sampling
// //
const char * llama_sampler_name(const struct llama_sampler * smpl) { // TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
return llama_sampler_name_impl(*smpl);
}
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
llama_sampler_accept_impl(*smpl, token);
}
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
llama_sampler_apply_impl(*smpl, cur_p);
}
void llama_sampler_reset(struct llama_sampler * smpl) {
llama_sampler_reset_impl(*smpl);
}
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
return llama_sampler_clone_impl(*smpl);
}
void llama_sampler_free(struct llama_sampler * smpl) {
if (smpl == nullptr) {
return;
}
llama_sampler_free_impl(smpl);
}
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return llama_sampler_chain_init_impl(params);
}
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
llama_sampler_chain_add_impl(*(struct llama_sampler_chain *) chain->ctx, smpl);
}
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
return llama_sampler_chain_get_impl(*(const struct llama_sampler_chain *) chain->ctx, i);
}
int llama_sampler_chain_n(const struct llama_sampler * chain) {
return llama_sampler_chain_n_impl(*(const struct llama_sampler_chain *) chain->ctx);
}
struct llama_sampler * llama_sampler_init_greedy(void) {
return llama_sampler_init_greedy_impl();
}
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
return llama_sampler_init_dist_impl(seed);
}
struct llama_sampler * llama_sampler_init_softmax(void) {
return llama_sampler_init_softmax_impl();
}
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return llama_sampler_init_top_k_impl(k);
}
struct llama_sampler * llama_sampler_init_top_p(float p, int32_t min_keep) {
return llama_sampler_init_top_p_impl(p, min_keep);
}
struct llama_sampler * llama_sampler_init_min_p(float p, int32_t min_keep) {
return llama_sampler_init_min_p_impl(p, min_keep);
}
struct llama_sampler * llama_sampler_init_tail_free(float z, int32_t min_keep) {
return llama_sampler_init_tail_free_impl(z, min_keep);
}
struct llama_sampler * llama_sampler_init_typical(float p, int32_t min_keep) {
return llama_sampler_init_typical_impl(p, min_keep);
}
struct llama_sampler * llama_sampler_init_temp(float temp) {
return llama_sampler_init_temp_impl(temp);
}
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
}
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) { struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) {
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m); return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m);
} }
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { // TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
return llama_sampler_init_mirostat_v2_impl(seed, tau, eta);
}
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
} }
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_penalties( struct llama_sampler * llama_sampler_init_penalties(
const struct llama_model * model, const struct llama_model * model,
int32_t penalty_last_n, int32_t penalty_last_n,
@ -20699,31 +20614,14 @@ struct llama_sampler * llama_sampler_init_penalties(
return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos);
} }
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( // TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_logit_bias(
const struct llama_model * model, const struct llama_model * model,
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias) { const llama_logit_bias * logit_bias) {
return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias);
} }
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
// TODO: do not allocate each time
std::vector<llama_token_data> cur(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_sampler_apply(smpl, &cur_p);
return cur_p.data[cur_p.selected].id;
}
// //
// model split // model split
// //