llama : add classifier-free guidance (#2135)
* Initial implementation * Remove debug print * Restore signature of llama_init_from_gpt_params * Free guidance context * Make freeing of guidance_ctx conditional * Make Classifier-Free Guidance a sampling function * Correct typo. CFG already means context-free grammar. * Record sampling time in llama_sample_classifier_free_guidance * Shift all values by the max value before applying logsoftmax * Fix styling based on review
This commit is contained in:
parent
3ec7e596b2
commit
c9c74b4e3f
5 changed files with 188 additions and 5 deletions
|
@ -48,6 +48,12 @@ struct gpt_params {
|
|||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
|
||||
// Classifier-Free Guidance
|
||||
// https://arxiv.org/abs/2306.17806
|
||||
std::string cfg_negative_prompt; // string to help guidance
|
||||
float cfg_scale = 1.f; // How strong is guidance
|
||||
float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits
|
||||
|
||||
std::string model = "models/7B/ggml-model.bin"; // model path
|
||||
std::string model_alias = "unknown"; // model alias
|
||||
std::string prompt = "";
|
||||
|
@ -99,6 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
|
|||
//
|
||||
|
||||
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
|
||||
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
|
||||
|
||||
//
|
||||
// Console utils
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue