llama : remove cfg smooth factor as it is only a reparameterization of the guidance scale (#2280)

This commit is contained in:
Guillaume "Vermeille" Sanchez 2023-07-21 12:58:36 +02:00 committed by GitHub
parent 73643f5fb1
commit ab0e26bdfb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 4 additions and 24 deletions

View file

@@ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_context * guidance_ctx,
float scale,
float smooth_factor) {
float scale) {
int64_t t_start_sample_us = ggml_time_us();
assert(ctx);
@@ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance(
for (int i = 0; i < n_vocab; ++i) {
float logit_guidance = logits_guidance[i];
float logit_base = logits_base[i];
logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
}
llama_log_softmax(logits_guidance, n_vocab);
for (int i = 0; i < n_vocab; ++i) {
float logit_base = logits_base[i];
float logit_guidance = logits_guidance[i];
candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance;
}
if (ctx) {