diff --git a/common/sampling.cpp b/common/sampling.cpp index cf3ee98d4..9964501da 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -206,7 +206,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/include/llama.h b/include/llama.h index 50c89c10f..02c565a3d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1070,7 +1070,8 @@ extern "C" { const struct llama_model * model, uint32_t seed, float tau, - float eta); + float eta, + int32_t m); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. diff --git a/src/llama.cpp b/src/llama.cpp index ce9209658..db50a0332 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20676,8 +20676,8 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa return llama_sampler_init_temp_ext_impl(temp, delta, exponent); } -struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) { - return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100); +struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) { + return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m); } struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {