sampling : allow passing m to mirostat sampler
This commit is contained in:
parent
8c972b69c1
commit
809bdcf767
3 changed files with 5 additions and 4 deletions
|
@ -206,7 +206,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||||
} else if (params.mirostat == 1) {
|
} else if (params.mirostat == 1) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||||
} else if (params.mirostat == 2) {
|
} else if (params.mirostat == 2) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||||
|
|
|
@ -1070,7 +1070,8 @@ extern "C" {
|
||||||
const struct llama_model * model,
|
const struct llama_model * model,
|
||||||
uint32_t seed,
|
uint32_t seed,
|
||||||
float tau,
|
float tau,
|
||||||
float eta);
|
float eta,
|
||||||
|
int32_t m);
|
||||||
|
|
||||||
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
|
|
|
@ -20676,8 +20676,8 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
|
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) {
|
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) {
|
||||||
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100);
|
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue