sampling : remove redundant indirection calls
ggml-ci
This commit is contained in:
parent
809bdcf767
commit
b448c753b9
4 changed files with 193 additions and 304 deletions
|
@ -992,9 +992,9 @@ extern "C" {
|
|||
//
|
||||
// Sampling API
|
||||
//
|
||||
// In the future, it will be utilized offload the sampling to the backends (e.g. GPU).
|
||||
// In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
//
|
||||
// TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model
|
||||
// TODO: in the future, the entire API that uses llama_model should start using llama_vocab
|
||||
|
||||
typedef void * llama_sampler_context_t;
|
||||
|
||||
|
@ -1045,16 +1045,16 @@ extern "C" {
|
|||
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
||||
|
||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, int32_t min_keep);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
|
||||
|
||||
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, int32_t min_keep);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
|
||||
|
||||
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, int32_t min_keep);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
|
||||
|
||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, int32_t min_keep);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
|
||||
|
||||
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
||||
|
|
|
@ -432,6 +432,167 @@ void llama_sampler_penalties_impl(
|
|||
cur_p->sorted = false;
|
||||
}
|
||||
|
||||
// llama_sampler API
|
||||
|
||||
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
||||
if (!smpl->iface) {
|
||||
return "(null)";
|
||||
}
|
||||
|
||||
return smpl->iface->name(smpl);
|
||||
}
|
||||
|
||||
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
if (smpl->iface->accept) {
|
||||
smpl->iface->accept(smpl, token);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(smpl->iface->apply);
|
||||
smpl->iface->apply(smpl, cur_p);
|
||||
}
|
||||
|
||||
void llama_sampler_reset(struct llama_sampler * smpl) {
|
||||
if (smpl->iface->reset) {
|
||||
smpl->iface->reset(smpl);
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
||||
if (smpl->iface->clone) {
|
||||
return smpl->iface->clone(smpl);
|
||||
}
|
||||
|
||||
if (smpl->ctx == nullptr) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ smpl->iface,
|
||||
/* .ctx = */ nullptr,
|
||||
};
|
||||
}
|
||||
|
||||
GGML_ABORT("the sampler does not support cloning");
|
||||
}
|
||||
|
||||
void llama_sampler_free(struct llama_sampler * smpl) {
|
||||
if (smpl == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (smpl->iface->free) {
|
||||
smpl->iface->free(smpl);
|
||||
}
|
||||
|
||||
delete smpl;
|
||||
}
|
||||
|
||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
// TODO: do not allocate each time
|
||||
std::vector<llama_token_data> cur(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
|
||||
llama_sampler_apply(smpl, &cur_p);
|
||||
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
|
||||
// sampler chain
|
||||
|
||||
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
|
||||
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_timing);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_accept(smpl, token);
|
||||
}
|
||||
|
||||
chain->n_sample++;
|
||||
},
|
||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_timing);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_apply(smpl, cur_p);
|
||||
}
|
||||
},
|
||||
/* .reset = */ [](struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_reset(smpl);
|
||||
}
|
||||
|
||||
chain->t_sample_us = 0;
|
||||
chain->n_sample = 0;
|
||||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_chain_init(chain_src->params);
|
||||
|
||||
for (auto * smpl : chain_src->samplers) {
|
||||
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_free(smpl);
|
||||
}
|
||||
|
||||
delete chain;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_chain_i,
|
||||
/* .ctx = */ new llama_sampler_chain {
|
||||
/* .params = */ params,
|
||||
/* .samplers = */ {},
|
||||
/* .t_sample_us = */ 0,
|
||||
/* .n_sample = */ 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
||||
auto * p = (llama_sampler_chain *) chain->ctx;
|
||||
p->samplers.push_back(smpl);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
||||
|
||||
if (i < 0 || i >= (int32_t) p->samplers.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return p->samplers[i];
|
||||
}
|
||||
|
||||
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
||||
|
||||
return p->samplers.size();
|
||||
}
|
||||
|
||||
//
|
||||
// samplers
|
||||
//
|
||||
|
@ -454,7 +615,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
|
|||
/* .free = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_greedy_impl() {
|
||||
struct llama_sampler * llama_sampler_init_greedy() {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_greedy_i,
|
||||
/* .ctx = */ nullptr,
|
||||
|
@ -481,14 +642,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx;
|
||||
return llama_sampler_init_dist_impl(ctx->seed);
|
||||
return llama_sampler_init_dist(ctx->seed);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_dist *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) {
|
||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_dist_i,
|
||||
/* .ctx = */ new llama_sampler_context_dist {
|
||||
|
@ -512,7 +673,7 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
|
|||
/* .free = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_softmax_impl() {
|
||||
struct llama_sampler * llama_sampler_init_softmax() {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_softmax_i,
|
||||
/* .ctx = */ nullptr,
|
||||
|
@ -535,14 +696,14 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx;
|
||||
return llama_sampler_init_top_k_impl(ctx->k);
|
||||
return llama_sampler_init_top_k(ctx->k);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_top_k *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_k_impl(int32_t k) {
|
||||
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_top_k_i,
|
||||
/* .ctx = */ new llama_sampler_context_top_k {
|
||||
|
@ -568,14 +729,14 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx;
|
||||
return llama_sampler_init_top_p_impl(ctx->p, ctx->min_keep);
|
||||
return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_top_p *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_p_impl(float p, size_t min_keep) {
|
||||
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_top_p_i,
|
||||
/* .ctx = */ new llama_sampler_context_top_p {
|
||||
|
@ -602,14 +763,14 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx;
|
||||
return llama_sampler_init_min_p_impl(ctx->p, ctx->min_keep);
|
||||
return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_min_p *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_min_p_impl(float p, size_t min_keep) {
|
||||
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_min_p_i,
|
||||
/* .ctx = */ new llama_sampler_context_min_p {
|
||||
|
@ -636,14 +797,14 @@ static struct llama_sampler_i llama_sampler_tail_free_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx;
|
||||
return llama_sampler_init_tail_free_impl(ctx->z, ctx->min_keep);
|
||||
return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_tail_free *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep) {
|
||||
struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_tail_free_i,
|
||||
/* .ctx = */ new llama_sampler_context_tail_free {
|
||||
|
@ -670,14 +831,14 @@ static struct llama_sampler_i llama_sampler_typical_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx;
|
||||
return llama_sampler_init_typical_impl(ctx->p, ctx->min_keep);
|
||||
return llama_sampler_init_typical(ctx->p, ctx->min_keep);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_typical *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_typical_impl(float p, size_t min_keep) {
|
||||
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_typical_i,
|
||||
/* .ctx = */ new llama_sampler_context_typical {
|
||||
|
@ -703,14 +864,14 @@ static struct llama_sampler_i llama_sampler_temp_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx;
|
||||
return llama_sampler_init_temp_impl(ctx->temp);
|
||||
return llama_sampler_init_temp(ctx->temp);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_temp *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp_impl(float temp) {
|
||||
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_temp_i,
|
||||
/* .ctx = */ new llama_sampler_context_temp {
|
||||
|
@ -744,14 +905,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
|||
/* .reset = */ nullptr,
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx;
|
||||
return llama_sampler_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent);
|
||||
return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_temp_ext *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, float exponent) {
|
||||
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_temp_ext_i,
|
||||
/* .ctx = */ new llama_sampler_context_temp_ext {
|
||||
|
@ -900,14 +1061,14 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
|||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
||||
return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta);
|
||||
return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) {
|
||||
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_mirostat_v2_i,
|
||||
/* .ctx = */ new llama_sampler_context_mirostat_v2 {
|
||||
|
@ -1146,143 +1307,3 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
|
|||
},
|
||||
};
|
||||
}
|
||||
|
||||
// sampler chain
|
||||
|
||||
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||
/* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
|
||||
/* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_timing);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_accept_impl(*smpl, token);
|
||||
}
|
||||
|
||||
chain->n_sample++;
|
||||
},
|
||||
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
time_meas tm(chain->t_sample_us, chain->params.no_timing);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_apply_impl(*smpl, cur_p);
|
||||
}
|
||||
},
|
||||
/* .reset = */ [](struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_reset_impl(*smpl);
|
||||
}
|
||||
|
||||
chain->t_sample_us = 0;
|
||||
chain->n_sample = 0;
|
||||
},
|
||||
/* .clone = */ [](const struct llama_sampler * smpl) {
|
||||
const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_chain_init_impl(chain_src->params);
|
||||
|
||||
auto * chain_dst = (llama_sampler_chain *) result->ctx;
|
||||
for (auto * smpl : chain_src->samplers) {
|
||||
llama_sampler_chain_add_impl(*chain_dst, llama_sampler_clone_impl(*smpl));
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
/* .free = */ [](struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
llama_sampler_free_impl(smpl);
|
||||
}
|
||||
|
||||
delete chain;
|
||||
},
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_init_impl(struct llama_sampler_chain_params params) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_chain_i,
|
||||
/* .ctx = */ new llama_sampler_chain {
|
||||
/* .params = */ params,
|
||||
/* .samplers = */ {},
|
||||
/* .t_sample_us = */ 0,
|
||||
/* .n_sample = */ 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
void llama_sampler_chain_add_impl(struct llama_sampler_chain & chain, struct llama_sampler * smpl) {
|
||||
chain.samplers.push_back(smpl);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_get_impl(const struct llama_sampler_chain & chain, int32_t i) {
|
||||
if (i < 0 || i >= (int32_t) chain.samplers.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return chain.samplers[i];
|
||||
}
|
||||
|
||||
int llama_sampler_chain_n_impl(const struct llama_sampler_chain & chain) {
|
||||
return chain.samplers.size();
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////
|
||||
|
||||
const char * llama_sampler_name_impl(const struct llama_sampler & smpl) {
|
||||
if (!smpl.iface) {
|
||||
return "(null)";
|
||||
}
|
||||
|
||||
return smpl.iface->name(&smpl);
|
||||
}
|
||||
|
||||
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
|
||||
if (smpl.iface->accept) {
|
||||
smpl.iface->accept(&smpl, token);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(smpl.iface->apply);
|
||||
smpl.iface->apply(&smpl, cur_p);
|
||||
}
|
||||
|
||||
void llama_sampler_reset_impl(struct llama_sampler & smpl) {
|
||||
if (smpl.iface->reset) {
|
||||
smpl.iface->reset(&smpl);
|
||||
}
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) {
|
||||
if (smpl.iface->clone) {
|
||||
return smpl.iface->clone(&smpl);
|
||||
}
|
||||
|
||||
if (smpl.ctx == nullptr) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ smpl.iface,
|
||||
/* .ctx = */ nullptr,
|
||||
};
|
||||
}
|
||||
|
||||
GGML_ABORT("the sampler does not support cloning");
|
||||
}
|
||||
|
||||
void llama_sampler_free_impl(struct llama_sampler * smpl) {
|
||||
if (smpl == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (smpl->iface->free) {
|
||||
smpl->iface->free(smpl);
|
||||
}
|
||||
|
||||
delete smpl;
|
||||
}
|
||||
|
|
|
@ -7,15 +7,6 @@
|
|||
struct llama_vocab;
|
||||
struct llama_grammar;
|
||||
|
||||
// samplers
|
||||
|
||||
const char * llama_sampler_name_impl (const struct llama_sampler & smpl);
|
||||
void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token);
|
||||
void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
|
||||
void llama_sampler_reset_impl ( struct llama_sampler & smpl);
|
||||
struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl);
|
||||
void llama_sampler_free_impl ( struct llama_sampler * smpl);
|
||||
|
||||
// sampler chain
|
||||
|
||||
struct llama_sampler_chain {
|
||||
|
@ -30,11 +21,6 @@ struct llama_sampler_chain {
|
|||
mutable int32_t n_sample;
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_init_impl( struct llama_sampler_chain_params params);
|
||||
void llama_sampler_chain_add_impl ( struct llama_sampler_chain & chain, struct llama_sampler * smpl);
|
||||
struct llama_sampler * llama_sampler_chain_get_impl (const struct llama_sampler_chain & chain, int32_t i);
|
||||
int llama_sampler_chain_n_impl (const struct llama_sampler_chain & chain);
|
||||
|
||||
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
||||
|
||||
// TODO: tmp exposed until test-sampling is fixed
|
||||
|
@ -45,17 +31,6 @@ void llama_sampler_penalties_impl(
|
|||
float penalty_freq,
|
||||
float penalty_present);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_greedy_impl ();
|
||||
struct llama_sampler * llama_sampler_init_dist_impl (uint32_t seed);
|
||||
struct llama_sampler * llama_sampler_init_softmax_impl ();
|
||||
struct llama_sampler * llama_sampler_init_top_k_impl (int32_t k);
|
||||
struct llama_sampler * llama_sampler_init_top_p_impl (float p, size_t min_keep);
|
||||
struct llama_sampler * llama_sampler_init_min_p_impl (float p, size_t min_keep);
|
||||
struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep);
|
||||
struct llama_sampler * llama_sampler_init_typical_impl (float p, size_t min_keep);
|
||||
struct llama_sampler * llama_sampler_init_temp_impl (float t);
|
||||
struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta, float exponent);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
uint32_t seed,
|
||||
|
@ -63,11 +38,6 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(
|
|||
float eta,
|
||||
int32_t m);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat_v2_impl(
|
||||
uint32_t seed,
|
||||
float tau,
|
||||
float eta);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const char * grammar_str,
|
||||
|
@ -82,7 +52,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl(
|
|||
bool penalize_nl,
|
||||
bool ignore_eos);
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl(
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias);
|
||||
|
|
112
src/llama.cpp
112
src/llama.cpp
|
@ -20592,102 +20592,17 @@ int32_t llama_chat_apply_template(
|
|||
// sampling
|
||||
//
|
||||
|
||||
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
||||
return llama_sampler_name_impl(*smpl);
|
||||
}
|
||||
|
||||
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
llama_sampler_accept_impl(*smpl, token);
|
||||
}
|
||||
|
||||
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
llama_sampler_apply_impl(*smpl, cur_p);
|
||||
}
|
||||
|
||||
void llama_sampler_reset(struct llama_sampler * smpl) {
|
||||
llama_sampler_reset_impl(*smpl);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
||||
return llama_sampler_clone_impl(*smpl);
|
||||
}
|
||||
|
||||
void llama_sampler_free(struct llama_sampler * smpl) {
|
||||
if (smpl == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sampler_free_impl(smpl);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
||||
return llama_sampler_chain_init_impl(params);
|
||||
}
|
||||
|
||||
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
||||
llama_sampler_chain_add_impl(*(struct llama_sampler_chain *) chain->ctx, smpl);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
||||
return llama_sampler_chain_get_impl(*(const struct llama_sampler_chain *) chain->ctx, i);
|
||||
}
|
||||
|
||||
int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
||||
return llama_sampler_chain_n_impl(*(const struct llama_sampler_chain *) chain->ctx);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_greedy(void) {
|
||||
return llama_sampler_init_greedy_impl();
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
return llama_sampler_init_dist_impl(seed);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_softmax(void) {
|
||||
return llama_sampler_init_softmax_impl();
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||
return llama_sampler_init_top_k_impl(k);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_p(float p, int32_t min_keep) {
|
||||
return llama_sampler_init_top_p_impl(p, min_keep);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_min_p(float p, int32_t min_keep) {
|
||||
return llama_sampler_init_min_p_impl(p, min_keep);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_tail_free(float z, int32_t min_keep) {
|
||||
return llama_sampler_init_tail_free_impl(z, min_keep);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_typical(float p, int32_t min_keep) {
|
||||
return llama_sampler_init_typical_impl(p, min_keep);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
||||
return llama_sampler_init_temp_impl(temp);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
||||
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
|
||||
}
|
||||
|
||||
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
|
||||
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) {
|
||||
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
||||
return llama_sampler_init_mirostat_v2_impl(seed, tau, eta);
|
||||
}
|
||||
|
||||
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
|
||||
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
|
||||
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
|
||||
}
|
||||
|
||||
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
|
||||
struct llama_sampler * llama_sampler_init_penalties(
|
||||
const struct llama_model * model,
|
||||
int32_t penalty_last_n,
|
||||
|
@ -20699,31 +20614,14 @@ struct llama_sampler * llama_sampler_init_penalties(
|
|||
return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos);
|
||||
}
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
|
||||
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
|
||||
struct llama_sampler * llama_sampler_init_logit_bias(
|
||||
const struct llama_model * model,
|
||||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias) {
|
||||
return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias);
|
||||
}
|
||||
|
||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
// TODO: do not allocate each time
|
||||
std::vector<llama_token_data> cur(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
|
||||
llama_sampler_apply(smpl, &cur_p);
|
||||
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
|
||||
//
|
||||
// model split
|
||||
//
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue