sampling : add adaptive temperature sampler

This commit is contained in:
Michael Coppola 2024-10-26 02:58:07 -04:00
parent 668750357e
commit 4f80618716
7 changed files with 93 additions and 11 deletions

View file

@ -1072,6 +1072,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.dynatemp_exponent = std::stof(value);
}
).set_sparam());
// --temp-adaptive: a flag that takes no value argument (the handler receives only `params`).
// When present it sets sparams.temp_adaptive to true; the flag cannot be used to set it back to false.
add_opt(common_arg(
{"--temp-adaptive"},
"ignore arguments for temp and dynatemp, and automatically set temperature based on entropy",
[](common_params & params) {
params.sparams.temp_adaptive = true;
}
).set_sparam());
add_opt(common_arg(
{"--mirostat"}, "N",
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"

View file

@ -132,6 +132,7 @@ struct common_sampler_params {
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool temp_adaptive = false; // enables automatic adaptive setting of temperature
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

View file

@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f, temp_adaptive = %d\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, temp_adaptive,
mirostat, mirostat_eta, mirostat_tau);
return std::string(result);
@ -188,28 +188,34 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TFS_Z:
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free (params.tfs_z, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
{
if (!params.temp_adaptive) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
} else {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_adaptive());
}
}
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break;
default:
GGML_ASSERT(false && "unknown sampler type");

View file

@ -812,6 +812,7 @@ struct server_context {
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.temp_adaptive = json_value(data, "temp_adaptive", default_sparams.temp_adaptive);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
@ -1142,6 +1143,7 @@ struct server_context {
{"seed", slot.sparams.seed},
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
{"temperature", slot.sparams.temp},
{"temp_adaptive", slot.sparams.temp_adaptive},
{"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
{"top_k", slot.sparams.top_k},

View file

@ -1099,6 +1099,9 @@ extern "C" {
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
/// @details Adaptive temperature implementation described in the paper https://arxiv.org/abs/2410.01104.
/// Takes no parameters: the temperature is derived on every apply from the entropy of the candidate
/// distribution. The inverse temperature is clamped to be >= 1, so the applied temperature is never above 1.
LLAMA_API struct llama_sampler * llama_sampler_init_temp_adaptive (void);
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);

View file

@ -1082,6 +1082,55 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
};
}
// temp-adaptive
// Name hook of the adaptive-temperature sampler interface.
// The sampler argument is ignored; every call yields the same constant string "temp-adaptive".
static const char * llama_sampler_temp_adaptive_name(const struct llama_sampler * /*smpl*/) {
    static constexpr char k_name[] = "temp-adaptive";

    return k_name;
}
// Apply hook of the adaptive-temperature sampler.
//
// Steps:
//   1. run the softmax helper so that cur_p->data[i].p can be read,
//   2. measure the entropy H = -sum(p * log(p + 1e-9)) of the candidates (natural log),
//   3. derive an inverse temperature (beta) from H:
//        H <= 0.5 -> beta = 1 (low-entropy distributions are left untouched)
//        H >  0.5 -> beta = max(1, 4th-degree polynomial in H)
//      beta is never below 1, so the distribution is never flattened,
//   4. hand 1/beta to llama_sampler_temp_impl as the temperature.
//
// The sampler itself carries no state, hence the unnamed first parameter.
// NOTE(review): the polynomial coefficients presumably come from the paper referenced
// for this sampler (arXiv 2410.01104) - confirm before changing them.
static void llama_sampler_temp_adaptive_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
    llama_sampler_softmax_impl(cur_p);

    // entropy of the candidate distribution; the 1e-9 offset keeps the log finite when p == 0
    float h = 0.0f;
    for (size_t idx = 0; idx < cur_p->size; ++idx) {
        const auto & cand = cur_p->data[idx];
        h += -cand.p * logf(cand.p + 1e-9);
    }

    // inverse temperature; stays at 1 (i.e. temperature 1) unless the entropy is high enough
    float inv_temp = 1.0f;
    if (h > 0.5) {
        // keep this as one expression: the term order is part of the numeric result
        inv_temp = -0.037 * powf(h, 4)
                 +  0.481 * powf(h, 3)
                 + -2.3   * powf(h, 2)
                 +  4.917 * h
                 + -1.791;

        // never increase entropy: beta below 1 would mean a temperature above 1
        if (inv_temp < 1.0) {
            inv_temp = 1.0;
        }
    }

    // beta = 1 / temp
    llama_sampler_temp_impl(cur_p, 1.0f / inv_temp);
}
// Interface table for the adaptive-temperature sampler.
// Only the name and apply hooks are provided; accept, reset, clone and free are left null.
// The initializer is positional, so the entries must stay in the order of the struct's fields
// (the inline /* .field = */ comments are labels only).
// NOTE(review): presumably the generic sampler code tolerates null hooks for a sampler
// whose ctx is null - confirm against the llama_sampler_* dispatch functions.
static struct llama_sampler_i llama_sampler_temp_adaptive_i = {
/* .name = */ llama_sampler_temp_adaptive_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_adaptive_apply,
/* .reset = */ nullptr,
/* .clone = */ nullptr,
/* .free = */ nullptr,
};
// Creates an adaptive-temperature sampler.
//
// The returned object is allocated with `new`, points at the temp-adaptive interface table
// and has a null context, since this sampler takes no configuration and keeps no state.
// NOTE(review): the release path is not visible here - presumably the usual sampler free
// function is expected to delete it; confirm.
struct llama_sampler * llama_sampler_init_temp_adaptive() {
    llama_sampler * smpl = new llama_sampler {
        /* .iface = */ &llama_sampler_temp_adaptive_i,
        /* .ctx   = */ nullptr,
    };

    return smpl;
}
// xtc
struct llama_sampler_xtc {

View file

@ -72,6 +72,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
tester.check();
}
// Test helper for the adaptive-temperature sampler.
// Builds a sampler_tester from `probs` and `probs_expected`, applies the temp-adaptive sampler
// followed by a dist sampler, dumps the candidate array before and after, and finally runs
// the tester's check.
// NOTE(review): presumably check() compares the resulting probabilities with `probs_expected`
// and the 0 passed to llama_sampler_init_dist is the RNG seed - confirm against their definitions.
static void test_temp_adaptive(const std::vector<float> & probs, const std::vector<float> & probs_expected) {
sampler_tester tester(probs, probs_expected);
DUMP(&tester.cur_p);
tester.apply(llama_sampler_init_temp_adaptive());
tester.apply(llama_sampler_init_dist(0));
DUMP(&tester.cur_p);
tester.check();
}
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
sampler_tester tester(probs, probs_expected);
@ -311,7 +322,10 @@ int main(void) {
ggml_time_init();
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
test_temp({0.4f, 0.3f, 0.2f, 0.1f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
test_temp_adaptive({0.1f, 0.2f, 0.3f, 0.4f}, {0.488836, 0.304651, 0.156445, 0.050068});
test_temp_adaptive({0.7f, 0.1f, 0.1f, 0.1f}, {0.764643, 0.078452, 0.078452, 0.078452});
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);