Expose exponent_val in the dynamic temperature sampler

This commit is contained in:
l3utterfly 2024-01-19 10:05:59 +09:00
parent e1f91ae24a
commit a2c94ae5be
4 changed files with 6 additions and 5 deletions

View file

@ -130,6 +130,7 @@ static void sampler_queue(
const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
@ -154,7 +155,7 @@ static void sampler_queue(
dynatemp_min = dynatemp_min<0?0:dynatemp_min;
dynatemp_max = dynatemp_max<0?0:dynatemp_max;
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max);
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
}
else
{

View file

@ -19,6 +19,7 @@ typedef struct llama_sampling_params {
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.10f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled

View file

@ -7783,13 +7783,11 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
}
}
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp = 0, float max_temp = 2.0f) {
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
const int64_t t_start_sample_us = ggml_time_us();
llama_sample_softmax(ctx, candidates_p);
float exponent_val = 1.0f;
// Calculate entropy of the softmax probabilities
float entropy = 0.0f;
for (size_t i = 0; i < candidates_p->size; ++i) {

View file

@ -779,7 +779,8 @@ extern "C" {
struct llama_context * ctx,
llama_token_data_array * candidates_p,
float min_temp,
float max_temp);
float max_temp,
float exponent_val);
LLAMA_API void llama_sample_temp(
struct llama_context * ctx,