sampling : allow passing m to mirostat sampler

This commit is contained in:
Georgi Gerganov 2024-09-06 12:06:00 +03:00
parent 8c972b69c1
commit 809bdcf767
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 5 additions and 4 deletions

View file

@ -206,7 +206,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));

View file

@ -1070,7 +1070,8 @@ extern "C" {
const struct llama_model * model,
uint32_t seed,
float tau,
float eta);
float eta,
int32_t m);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.

View file

@ -20676,8 +20676,8 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
}
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) {
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100);
// Public entry point for creating a Mirostat sampler.
// Thin wrapper: forwards the model's vocab together with seed, tau, eta and m,
// all unchanged, to llama_sampler_init_mirostat_impl and returns its result.
// m is supplied by the caller rather than fixed inside this function.
struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) {
return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m);
}
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {