diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 8d3a3a3b6..8733d5feb 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -54,20 +54,6 @@ void sigint_handler(int signo) {
 }
 #endif
 
-void inplace_log_softmax(float* logits, int n_vocab) {
-    float sum = 0.f;
-    for (int i = 0; i < n_vocab; ++i) {
-        float p = expf(logits[i]);
-        logits[i] = p;
-        sum += p;
-    }
-
-    for (int i = 0; i < n_vocab; ++i) {
-        float p = logits[i];
-        logits[i] = logf(p/ sum);
-    }
-}
-
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -554,21 +540,6 @@ int main(int argc, char ** argv) {
                     logits[it->first] += it->second;
                 }
 
-                if (guidance_ctx) {
-                    inplace_log_softmax(logits, n_vocab);
-                    auto* guidance_logits = llama_get_logits(guidance_ctx);
-                    inplace_log_softmax(guidance_logits, n_vocab);
-
-                    for (int i = 0; i < n_vocab; ++i) {
-                        guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i];
-                    }
-                    inplace_log_softmax(guidance_logits, n_vocab);
-
-                    for (int i = 0; i < n_vocab; ++i) {
-                        logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor);
-                    }
-                }
-
                 std::vector<llama_token_data> candidates;
                 candidates.reserve(n_vocab);
                 for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -577,6 +548,10 @@ int main(int argc, char ** argv) {
 
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
+                if (guidance_ctx) {
+                    llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor);
+                }
+
                 // Apply penalties
                 float nl_logit = logits[llama_token_nl()];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
diff --git a/llama.cpp b/llama.cpp
index ee6ec0920..5a8c6cf3b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2141,6 +2141,61 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
     }
 }
 
+template <typename T, typename LogitAccessor>
+void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
+    float sum = 0.f;
+    for (int i = 0; i < size; ++i) {
+        float& logit = logit_accessor(array[i]);
+        float p = expf(logit);
+        sum += p;
+        logit = p;
+    }
+
+    for (int i = 0; i < size; ++i) {
+        float& logit = logit_accessor(array[i]);
+        logit = logf(logit / sum);
+    }
+}
+
+void llama_sample_context_free_guidance(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        struct llama_context * guidance_ctx,
+        float scale,
+        float smooth_factor) {
+    assert(ctx);
+    auto n_vocab = llama_n_vocab(ctx);
+    assert(n_vocab == (int)candidates->size);
+    assert(!candidates->sorted);
+
+    auto logit_from_token_data = [](llama_token_data& data) -> float& {
+        return data.logit;
+    };
+
+    auto logit_from_float = [](float& item) -> float& {
+        return item;
+    };
+
+    llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);
+
+    auto* guidance_logits = llama_get_logits(guidance_ctx);
+    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float guidance_logit = guidance_logits[i];
+        float base_logit = candidates->data[i].logit;
+        guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit;
+    }
+
+    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float base_logit = candidates->data[i].logit;
+        float guidance_logit = guidance_logits[i];
+
+        candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit;
+    }
+}
+
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
diff --git a/llama.h b/llama.h
index c1e7dab9f..efac46ea8 100644
--- a/llama.h
+++ b/llama.h
@@ -307,6 +307,18 @@ extern "C" {
 
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
 
+    /// @details Apply classifier-free guidance to the logits as described in the academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens; the logits must be directly extracted from the original generation context without being sorted.
+    /// @param guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    /// @param smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only the original logits.
+    LLAMA_API void llama_sample_context_free_guidance(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        struct llama_context * guidance_ctx,
+        float scale,
+        float smooth_factor);
+
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
 
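
For reviewers, a minimal usage sketch (not part of the diff) of how a caller drives the new sampler: the guidance context is primed with the negative prompt, both contexts then see the same tokens, and the new call rewrites the candidate logits in place as `log_softmax(scale * (log_softmax(base) - log_softmax(guidance)) + log_softmax(guidance))`, blended back into the original log-probs by `smooth_factor`. Everything here except `llama_sample_context_free_guidance` is the existing llama.cpp API; `model`, `ctx_params`, `neg_tokens`, `inp_tokens`, `n_threads`, `cfg_scale`, and `cfg_smooth_factor` are hypothetical stand-ins for caller state (in `examples/main` they come from `gpt_params`).

```cpp
#include <vector>
#include "llama.h"

// Hypothetical sample-step excerpt showing where the new call sits;
// error handling and the rest of the sampling chain are omitted.
static void cfg_step_sketch(llama_model * model, llama_context_params ctx_params,
                            const std::vector<llama_token> & neg_tokens,
                            const std::vector<llama_token> & inp_tokens,
                            int n_threads, float cfg_scale, float cfg_smooth_factor) {
    llama_context * ctx          = llama_new_context_with_model(model, ctx_params);
    llama_context * guidance_ctx = llama_new_context_with_model(model, ctx_params);

    // Prime the guidance context with the negative prompt and the main context
    // with the real prompt; afterwards both must see the same appended tokens.
    llama_eval(guidance_ctx, neg_tokens.data(), (int) neg_tokens.size(), 0, n_threads);
    llama_eval(ctx,          inp_tokens.data(), (int) inp_tokens.size(), 0, n_threads);

    // Build the unsorted candidate array from the main context's logits ...
    const int n_vocab = llama_n_vocab(ctx);
    float * logits = llama_get_logits(ctx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // ... and let the new call rewrite candidates_p's logits in place.
    // cfg_scale == 1.0f leaves the distribution unchanged; cfg_smooth_factor
    // blends guided and original log-probs (1.0f = guidance only).
    llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx,
                                       cfg_scale, cfg_smooth_factor);

    // The existing penalty/temperature/sampling chain continues on candidates_p.
    llama_free(ctx);
    llama_free(guidance_ctx);
}
```

Compared to the removed `inplace_log_softmax` path, doing this inside the library on a `llama_token_data_array` keeps CFG composable with the existing penalty/sampling chain and usable from frontends other than `examples/main`.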