llama : add classifier-free guidance (#2135)
* Initial implementation * Remove debug print * Restore signature of llama_init_from_gpt_params * Free guidance context * Make freeing of guidance_ctx conditional * Make Classifier-Free Guidance a sampling function * Correct typo. CFG already means context-free grammar. * Record sampling time in llama_sample_classifier_free_guidance * Shift all values by the max value before applying logsoftmax * Fix styling based on review
This commit is contained in:
parent
3ec7e596b2
commit
c9c74b4e3f
5 changed files with 188 additions and 5 deletions
12
llama.h
12
llama.h
|
@ -309,6 +309,18 @@ extern "C" {
|
|||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
|
||||
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
||||
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||
/// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
|
||||
LLAMA_API void llama_sample_classifier_free_guidance(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_context * guidance_ctx,
|
||||
float scale,
|
||||
float smooth_factor);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue