sampling : add adaptive temperature sampler
This commit is contained in:
parent
668750357e
commit
4f80618716
7 changed files with 93 additions and 11 deletions
|
@ -1072,6 +1072,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.sparams.dynatemp_exponent = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--temp-adaptive"},
|
||||
"ignore arguments for temp and dynatemp, and automatically set temperature based on entropy",
|
||||
[](common_params & params) {
|
||||
params.sparams.temp_adaptive = true;
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat"}, "N",
|
||||
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
|
||||
|
|
|
@ -132,6 +132,7 @@ struct common_sampler_params {
|
|||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
bool ignore_eos = false;
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool temp_adaptive = false; // enables automatic adaptive setting of temperature
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
|
|
|
@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
|
|||
snprintf(result, sizeof(result),
|
||||
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f, temp_adaptive = %d\n"
|
||||
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
||||
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
|
||||
top_k, tfs_z, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, temp_adaptive,
|
||||
mirostat, mirostat_eta, mirostat_tau);
|
||||
|
||||
return std::string(result);
|
||||
|
@ -188,28 +188,34 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TFS_Z:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free (params.tfs_z, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
{
|
||||
if (!params.temp_adaptive) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||
} else {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_adaptive());
|
||||
}
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_INFILL:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
|
|
|
@ -812,6 +812,7 @@ struct server_context {
|
|||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||
slot.sparams.temp_adaptive = json_value(data, "temp_adaptive", default_sparams.temp_adaptive);
|
||||
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
||||
|
@ -1142,6 +1143,7 @@ struct server_context {
|
|||
{"seed", slot.sparams.seed},
|
||||
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
|
||||
{"temperature", slot.sparams.temp},
|
||||
{"temp_adaptive", slot.sparams.temp_adaptive},
|
||||
{"dynatemp_range", slot.sparams.dynatemp_range},
|
||||
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
||||
{"top_k", slot.sparams.top_k},
|
||||
|
|
|
@ -1099,6 +1099,9 @@ extern "C" {
|
|||
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
|
||||
|
||||
/// @details Adaptive temperature implementation described in the paper https://arxiv.org/abs/2410.01104.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_temp_adaptive (void);
|
||||
|
||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||
|
||||
|
|
|
@ -1082,6 +1082,55 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
|||
};
|
||||
}
|
||||
|
||||
// temp-adaptive
|
||||
|
||||
// Name callback for the adaptive temperature sampler.
// The sampler argument is unused; the same constant string is returned for every call.
static const char * llama_sampler_temp_adaptive_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "temp-adaptive";
    return sampler_name;
}
|
||||
|
||||
static void llama_sampler_temp_adaptive_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
||||
llama_sampler_softmax_impl(cur_p);
|
||||
|
||||
// calculate entropy
|
||||
float entropy = 0.0f;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
entropy += -cur_p->data[i].p * logf(cur_p->data[i].p + 1e-9);
|
||||
}
|
||||
|
||||
// calculate beta
|
||||
float beta = 0.0f;
|
||||
if (entropy > 0.5) { // don't overcorrect low-entropy heads
|
||||
beta = -0.037 * powf(entropy, 4)
|
||||
+ 0.481 * powf(entropy, 3)
|
||||
+ -2.3 * powf(entropy, 2)
|
||||
+ 4.917 * entropy
|
||||
+ -1.791;
|
||||
// never increase entropy
|
||||
beta = (beta < 1.0) ? 1.0 : beta;
|
||||
} else {
|
||||
beta = 1.0;
|
||||
}
|
||||
|
||||
// beta = 1 / temp
|
||||
llama_sampler_temp_impl(cur_p, 1.0f / beta);
|
||||
}
|
||||
|
||||
// Callback table for the adaptive temperature sampler.
// Only the `name` and `apply` hooks are provided; every other hook is left null
// (the sampler is created with a null ctx, see llama_sampler_init_temp_adaptive).
static struct llama_sampler_i llama_sampler_temp_adaptive_i = {
    /* .name   = */ llama_sampler_temp_adaptive_name,
    /* .accept = */ nullptr,
    /* .apply  = */ llama_sampler_temp_adaptive_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ nullptr,
    /* .free   = */ nullptr,
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp_adaptive() {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_temp_adaptive_i,
|
||||
/* .ctx = */ nullptr,
|
||||
};
|
||||
}
|
||||
|
||||
// xtc
|
||||
|
||||
struct llama_sampler_xtc {
|
||||
|
|
|
@ -72,6 +72,17 @@ static void test_temp(const std::vector<float> & probs, const std::vector<float>
|
|||
tester.check();
|
||||
}
|
||||
|
||||
// Runs the adaptive temperature sampler followed by the dist sampler over `probs`
// and lets the tester compare the outcome against `probs_expected`.
// The candidate array is dumped before and after the two samplers are applied.
static void test_temp_adaptive(const std::vector<float> & probs, const std::vector<float> & probs_expected) {
    sampler_tester tester(probs, probs_expected);

    DUMP(&tester.cur_p);

    // each sampler is created right before it is applied, one after the other
    auto * adaptive_smpl = llama_sampler_init_temp_adaptive();
    tester.apply(adaptive_smpl);

    auto * dist_smpl = llama_sampler_init_dist(0);
    tester.apply(dist_smpl);

    DUMP(&tester.cur_p);

    tester.check();
}
|
||||
|
||||
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
|
||||
sampler_tester tester(probs, probs_expected);
|
||||
|
||||
|
@ -311,7 +322,10 @@ int main(void) {
|
|||
ggml_time_init();
|
||||
|
||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
||||
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
||||
test_temp({0.4f, 0.3f, 0.2f, 0.1f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
||||
|
||||
test_temp_adaptive({0.1f, 0.2f, 0.3f, 0.4f}, {0.488836, 0.304651, 0.156445, 0.050068});
|
||||
test_temp_adaptive({0.7f, 0.1f, 0.1f, 0.1f}, {0.764643, 0.078452, 0.078452, 0.078452});
|
||||
|
||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
||||
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue