Smoothing factor backport

This commit is contained in:
kalomaze 2024-04-02 17:10:46 -05:00
parent f87f7b8986
commit b5dbcf6f5e
5 changed files with 23 additions and 6 deletions

View file

@ -132,6 +132,7 @@ static void sampler_queue(
const float temp = params.temp; const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range; const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent; const float dynatemp_exponent = params.dynatemp_exponent;
const float smoothing_factor = params.smoothing_factor;
const int32_t top_k = params.top_k; const int32_t top_k = params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
const float min_p = params.min_p; const float min_p = params.min_p;
@ -147,10 +148,10 @@ static void sampler_queue(
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::TEMPERATURE: case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) { if (dynatemp_range > 0 || smoothing_factor > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent, smoothing_factor);
} else { } else {
llama_sample_temp(ctx_main, &cur_p, temp); llama_sample_temp(ctx_main, &cur_p, temp);
} }

View file

@ -31,6 +31,7 @@ typedef struct llama_sampling_params {
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
float smoothing_factor = 0.0f; // controls the quadratic adjustment in smooth sampling
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled

View file

@ -839,6 +839,7 @@ struct server_context {
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.smoothing_factor = json_value(data, "smoothing_factor", default_sparams.smoothing_factor);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);

View file

@ -12183,7 +12183,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
} }
} }
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { void llama_sample_entropy(struct llama_context* ctx, llama_token_data_array* candidates_p, float min_temp, float max_temp, float exponent_val, float smoothing_factor) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
// no need to do anything if there is only one (or zero) candidates // no need to do anything if there is only one (or zero) candidates
@ -12191,6 +12191,19 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c
return; return;
} }
// Apply smoothing if smoothing_factor is > 0. Do not change base implementation otherwise.
if (smoothing_factor > 0 && candidates_p->size > 1) {
llama_sample_softmax(ctx, candidates_p);
float h = candidates_p->data[0].logit; // Find the maximum logit for h to be added after the transformation
// Apply quadratic transformation using the smoothing_factor
for (size_t i = 0; i < candidates_p->size; ++i) {
float logit_shifted = candidates_p->data[i].logit - h;
candidates_p->data[i].logit = -smoothing_factor * logit_shifted * logit_shifted + h;
}
llama_sample_softmax(ctx, candidates_p);
}
// Calculate maximum possible entropy // Calculate maximum possible entropy
float max_entropy = -logf(1.0f / candidates_p->size); float max_entropy = -logf(1.0f / candidates_p->size);

View file

@ -864,13 +864,14 @@ extern "C" {
float p, float p,
size_t min_keep); size_t min_keep);
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. /// @details Dynamic temperature implementation + Smooth Sampling implementations wrapped into one function, no research papers available
LLAMA_API void llama_sample_entropy( LLAMA_API void llama_sample_entropy(
struct llama_context * ctx, struct llama_context * ctx,
llama_token_data_array * candidates_p, llama_token_data_array * candidates_p,
float min_temp, float min_temp,
float max_temp, float max_temp,
float exponent_val); float exponent_val,
float smoothing_factor);
LLAMA_API void llama_sample_temp( LLAMA_API void llama_sample_temp(
struct llama_context * ctx, struct llama_context * ctx,