llama : default sampling changes + greedy update (#9897)

* llama : deprecate softmax sampler + fix dist sampler

ggml-ci

* tests : replace macros with functions

ggml-ci

* sampling : change temperature sampler logic

For t <= 0.0f, keep the max logit intact and set the rest to -inf

* cont : no need for special "greedy" logic

top-k == 1 is the same

* tests : init prob correctly

* llama : handle temp <= 0.0 in the temp_ext sampler too

ggml-ci

* cont : avoid extra loop in temperature sampler for sub-zero temp

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-21 09:46:40 +03:00 committed by GitHub
parent bc21975084
commit 55e47786e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 202 additions and 218 deletions

View file

@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
}
*/
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
if (temp <= 0.0f) {
// find the token with the highest logit and set the rest to -inf
size_t max_i = 0;
float max_l = cur_p->data[0].logit;
for (size_t i = 1; i < cur_p->size; ++i) {
if (cur_p->data[i ].logit > max_l) {
cur_p->data[max_i].logit = -INFINITY;
max_i = i;
max_l = cur_p->data[i].logit;
} else {
cur_p->data[i].logit = -INFINITY;
}
}
return;
}
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= temp;
}
}
static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
GGML_ASSERT(cur_p->size > 0);
@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
llama_sampler_softmax_impl(cur_p);
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
}
@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= ctx->temp;
}
llama_sampler_temp_impl(cur_p, ctx->temp);
}
static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
if (ctx->delta > 0) {
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
const float max_temp = ctx->temp + ctx->delta;
float exponent_val = ctx->exponent;
// no need to do anything if there is only one (or zero) candidates
@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
#endif
// Apply the dynamically calculated temperature scaling
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= dyn_temp;
}
llama_sampler_temp_impl(cur_p, dyn_temp);
// Re-compute softmax probabilities after scaling logits with dynamic temperature
const double max_l_double = cur_p->data[0].logit;
@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
}
#endif
} else {
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].logit /= ctx->temp;
}
llama_sampler_temp_impl(cur_p, ctx->temp);
}
}