sampling : add adaptive temperature sampler
This commit is contained in:
parent
668750357e
commit
4f80618716
7 changed files with 93 additions and 11 deletions
|
@ -1072,6 +1072,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
params.sparams.dynatemp_exponent = std::stof(value);
|
params.sparams.dynatemp_exponent = std::stof(value);
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"--temp-adaptive"},
|
||||||
|
"ignore arguments for temp and dynatemp, and automatically set temperature based on entropy",
|
||||||
|
[](common_params & params) {
|
||||||
|
params.sparams.temp_adaptive = true;
|
||||||
|
}
|
||||||
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--mirostat"}, "N",
|
{"--mirostat"}, "N",
|
||||||
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
|
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
|
||||||
|
|
|
@ -132,6 +132,7 @@ struct common_sampler_params {
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
bool no_perf = false; // disable performance metrics
|
bool no_perf = false; // disable performance metrics
|
||||||
|
bool temp_adaptive = false; // enables automatic adaptive setting of temperature
|
||||||
|
|
||||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||||
|
|
||||||
|
|
|
@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
|
||||||
snprintf(result, sizeof(result),
|
snprintf(result, sizeof(result),
|
||||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f, temp_adaptive = %d\n"
|
||||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, temp_adaptive,
|
||||||
mirostat, mirostat_eta, mirostat_tau);
|
mirostat, mirostat_eta, mirostat_tau);
|
||||||
|
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
|
@ -200,13 +200,19 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TFS_Z:
|
case COMMON_SAMPLER_TYPE_TFS_Z:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free (params.tfs_z, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||||
|
{
|
||||||
|
if (!params.temp_adaptive) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||||
|
} else {
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_adaptive());
|
||||||
|
}
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_INFILL:
|
case COMMON_SAMPLER_TYPE_INFILL:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||||
|
|
|
@ -812,6 +812,7 @@ struct server_context {
|
||||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||||
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||||
|
slot.sparams.temp_adaptive = json_value(data, "temp_adaptive", default_sparams.temp_adaptive);
|
||||||
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||||
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||||
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
||||||
|
@ -1142,6 +1143,7 @@ struct server_context {
|
||||||
{"seed", slot.sparams.seed},
|
{"seed", slot.sparams.seed},
|
||||||
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
|
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
|
||||||
{"temperature", slot.sparams.temp},
|
{"temperature", slot.sparams.temp},
|
||||||
|
{"temp_adaptive", slot.sparams.temp_adaptive},
|
||||||
{"dynatemp_range", slot.sparams.dynatemp_range},
|
{"dynatemp_range", slot.sparams.dynatemp_range},
|
||||||
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
||||||
{"top_k", slot.sparams.top_k},
|
{"top_k", slot.sparams.top_k},
|
||||||
|
|
|
@ -1099,6 +1099,9 @@ extern "C" {
|
||||||
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
||||||
|
|
||||||
|
/// @details Adaptive temperature implementation described in the paper https://arxiv.org/abs/2410.01104.
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_init_temp_adaptive (void);
|
||||||
|
|
||||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||||
|
|
||||||
|
|
|
@ -1082,6 +1082,55 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// temp-adaptive
|
||||||
|
|
||||||
|
static const char * llama_sampler_temp_adaptive_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
return "temp-adaptive";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_temp_adaptive_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
||||||
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
|
// calculate entropy
|
||||||
|
float entropy = 0.0f;
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
entropy += -cur_p->data[i].p * logf(cur_p->data[i].p + 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate beta
|
||||||
|
float beta = 0.0f;
|
||||||
|
if (entropy > 0.5) { // don't overcorrect low-entropy heads
|
||||||
|
beta = -0.037 * powf(entropy, 4)
|
||||||
|
+ 0.481 * powf(entropy, 3)
|
||||||
|
+ -2.3 * powf(entropy, 2)
|
||||||
|
+ 4.917 * entropy
|
||||||
|
+ -1.791;
|
||||||
|
// never increase entropy
|
||||||
|
beta = (beta < 1.0) ? 1.0 : beta;
|
||||||
|
} else {
|
||||||
|
beta = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// beta = 1 / temp
|
||||||
|
llama_sampler_temp_impl(cur_p, 1.0f / beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct llama_sampler_i llama_sampler_temp_adaptive_i = {
|
||||||
|
/* .name = */ llama_sampler_temp_adaptive_name,
|
||||||
|
/* .accept = */ nullptr,
|
||||||
|
/* .apply = */ llama_sampler_temp_adaptive_apply,
|
||||||
|
/* .reset = */ nullptr,
|
||||||
|
/* .clone = */ nullptr,
|
||||||
|
/* .free = */ nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_sampler * llama_sampler_init_temp_adaptive() {
|
||||||
|
return new llama_sampler {
|
||||||
|
/* .iface = */ &llama_sampler_temp_adaptive_i,
|
||||||
|
/* .ctx = */ nullptr,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// xtc
|
// xtc
|
||||||
|
|
||||||
struct llama_sampler_xtc {
|
struct llama_sampler_xtc {
|
||||||
|
|
|
@ -72,6 +72,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
|
||||||
tester.check();
|
tester.check();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_temp_adaptive(const std::vector<float> & probs, const std::vector<float> & probs_expected) {
|
||||||
|
sampler_tester tester(probs, probs_expected);
|
||||||
|
|
||||||
|
DUMP(&tester.cur_p);
|
||||||
|
tester.apply(llama_sampler_init_temp_adaptive());
|
||||||
|
tester.apply(llama_sampler_init_dist(0));
|
||||||
|
DUMP(&tester.cur_p);
|
||||||
|
|
||||||
|
tester.check();
|
||||||
|
}
|
||||||
|
|
||||||
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
|
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
|
||||||
sampler_tester tester(probs, probs_expected);
|
sampler_tester tester(probs, probs_expected);
|
||||||
|
|
||||||
|
@ -311,7 +322,10 @@ int main(void) {
|
||||||
ggml_time_init();
|
ggml_time_init();
|
||||||
|
|
||||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
||||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
test_temp({0.4f, 0.3f, 0.2f, 0.1f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
||||||
|
|
||||||
|
test_temp_adaptive({0.1f, 0.2f, 0.3f, 0.4f}, {0.488836, 0.304651, 0.156445, 0.050068});
|
||||||
|
test_temp_adaptive({0.7f, 0.1f, 0.1f, 0.1f}, {0.764643, 0.078452, 0.078452, 0.078452});
|
||||||
|
|
||||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
||||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue