llama : add classifier-free guidance (#2135)
* Initial implementation * Remove debug print * Restore signature of llama_init_from_gpt_params * Free guidance context * Make freeing of guidance_ctx conditional * Make Classifier-Free Guidance a sampling function * Correct typo. CFG already means context-free grammar. * Record sampling time in llama_sample_classifier_free_guidance * Shift all values by the max value before applying logsoftmax * Fix styling based on review
This commit is contained in:
		
							parent
							
								
									3ec7e596b2
								
							
						
					
					
						commit
						c9c74b4e3f
					
				
					 5 changed files with 188 additions and 5 deletions
				
			
		|  | @ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|             params.mirostat_tau = std::stof(argv[i]); |             params.mirostat_tau = std::stof(argv[i]); | ||||||
|  |         } else if (arg == "--cfg-negative-prompt") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             params.cfg_negative_prompt = argv[i]; | ||||||
|  |         } else if (arg == "--cfg-scale") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             params.cfg_scale = std::stof(argv[i]); | ||||||
|  |         } else if (arg == "--cfg-smooth-factor") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             params.cfg_smooth_factor = std::stof(argv[i]); | ||||||
|         } else if (arg == "-b" || arg == "--batch-size") { |         } else if (arg == "-b" || arg == "--batch-size") { | ||||||
|             if (++i >= argc) { |             if (++i >= argc) { | ||||||
|                 invalid_param = true; |                 invalid_param = true; | ||||||
|  | @ -469,6 +487,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | ||||||
|     fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n"); |     fprintf(stderr, "                        modifies the likelihood of token appearing in the completion,\n"); | ||||||
|     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); |     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); | ||||||
|     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); |     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); | ||||||
|  |     fprintf(stderr, "  --cfg-negative-prompt PROMPT \n"); | ||||||
|  |     fprintf(stderr, "                        negative prompt to use for guidance. (default: empty)\n"); | ||||||
|  |     fprintf(stderr, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); | ||||||
|  |     fprintf(stderr, "  --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor); | ||||||
|     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx); |     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx); | ||||||
|     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); |     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); | ||||||
|     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n"); |     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n"); | ||||||
|  | @ -535,7 +557,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s | ||||||
|     return res; |     return res; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) { | struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { | ||||||
|     auto lparams = llama_context_default_params(); |     auto lparams = llama_context_default_params(); | ||||||
| 
 | 
 | ||||||
|     lparams.n_ctx        = params.n_ctx; |     lparams.n_ctx        = params.n_ctx; | ||||||
|  | @ -551,6 +573,12 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par | ||||||
|     lparams.logits_all   = params.perplexity; |     lparams.logits_all   = params.perplexity; | ||||||
|     lparams.embedding    = params.embedding; |     lparams.embedding    = params.embedding; | ||||||
| 
 | 
 | ||||||
|  |     return lparams; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) { | ||||||
|  |     auto lparams = llama_context_params_from_gpt_params(params); | ||||||
|  | 
 | ||||||
|     llama_model * model  = llama_load_model_from_file(params.model.c_str(), lparams); |     llama_model * model  = llama_load_model_from_file(params.model.c_str(), lparams); | ||||||
|     if (model == NULL) { |     if (model == NULL) { | ||||||
|         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); |         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); | ||||||
|  |  | ||||||
|  | @ -48,6 +48,12 @@ struct gpt_params { | ||||||
|     float   mirostat_tau      = 5.00f; // target entropy
 |     float   mirostat_tau      = 5.00f; // target entropy
 | ||||||
|     float   mirostat_eta      = 0.10f; // learning rate
 |     float   mirostat_eta      = 0.10f; // learning rate
 | ||||||
| 
 | 
 | ||||||
|  |     // Classifier-Free Guidance
 | ||||||
|  |     // https://arxiv.org/abs/2306.17806
 | ||||||
|  |     std::string cfg_negative_prompt;       // string to help guidance
 | ||||||
|  |     float       cfg_scale         = 1.f;   // How strong is guidance
 | ||||||
|  |     float       cfg_smooth_factor = 1.f;   // Smooth factor between old and new logits
 | ||||||
|  | 
 | ||||||
|     std::string model             = "models/7B/ggml-model.bin"; // model path
 |     std::string model             = "models/7B/ggml-model.bin"; // model path
 | ||||||
|     std::string model_alias       = "unknown"; // model alias
 |     std::string model_alias       = "unknown"; // model alias
 | ||||||
|     std::string prompt            = ""; |     std::string prompt            = ""; | ||||||
|  | @ -99,6 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s | ||||||
| //
 | //
 | ||||||
| 
 | 
 | ||||||
| std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params); | std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params); | ||||||
|  | struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); | ||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
| // Console utils
 | // Console utils
 | ||||||
|  |  | ||||||
|  | @ -109,10 +109,16 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|     llama_model * model; |     llama_model * model; | ||||||
|     llama_context * ctx; |     llama_context * ctx; | ||||||
|  |     llama_context * ctx_guidance = NULL; | ||||||
|     g_ctx = &ctx; |     g_ctx = &ctx; | ||||||
| 
 | 
 | ||||||
|     // load the model and apply lora adapter, if any
 |     // load the model and apply lora adapter, if any
 | ||||||
|     std::tie(model, ctx) = llama_init_from_gpt_params(params); |     std::tie(model, ctx) = llama_init_from_gpt_params(params); | ||||||
|  |     if (params.cfg_scale > 1.f) { | ||||||
|  |         struct llama_context_params lparams = llama_context_params_from_gpt_params(params); | ||||||
|  |         ctx_guidance = llama_new_context_with_model(model, lparams); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     if (model == NULL) { |     if (model == NULL) { | ||||||
|         fprintf(stderr, "%s: error: unable to load model\n", __func__); |         fprintf(stderr, "%s: error: unable to load model\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
|  | @ -183,15 +189,28 @@ int main(int argc, char ** argv) { | ||||||
|     // tokenize the prompt
 |     // tokenize the prompt
 | ||||||
|     std::vector<llama_token> embd_inp; |     std::vector<llama_token> embd_inp; | ||||||
| 
 | 
 | ||||||
|     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { |  | ||||||
|     // Add a space in front of the first character to match OG llama tokenizer behavior
 |     // Add a space in front of the first character to match OG llama tokenizer behavior
 | ||||||
|     params.prompt.insert(0, 1, ' '); |     params.prompt.insert(0, 1, ' '); | ||||||
| 
 | 
 | ||||||
|  |     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { | ||||||
|         embd_inp = ::llama_tokenize(ctx, params.prompt, true); |         embd_inp = ::llama_tokenize(ctx, params.prompt, true); | ||||||
|     } else { |     } else { | ||||||
|         embd_inp = session_tokens; |         embd_inp = session_tokens; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     // Tokenize negative prompt
 | ||||||
|  |     std::vector<llama_token> guidance_inp; | ||||||
|  |     int guidance_offset = 0; | ||||||
|  |     int original_prompt_len = 0; | ||||||
|  |     if (ctx_guidance) { | ||||||
|  |         params.cfg_negative_prompt.insert(0, 1, ' '); | ||||||
|  |         guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true); | ||||||
|  | 
 | ||||||
|  |         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true); | ||||||
|  |         original_prompt_len = original_inp.size(); | ||||||
|  |         guidance_offset = (int)guidance_inp.size() - original_prompt_len; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     const int n_ctx = llama_n_ctx(ctx); |     const int n_ctx = llama_n_ctx(ctx); | ||||||
| 
 | 
 | ||||||
|     if ((int) embd_inp.size() > n_ctx - 4) { |     if ((int) embd_inp.size() > n_ctx - 4) { | ||||||
|  | @ -258,6 +277,16 @@ int main(int argc, char ** argv) { | ||||||
|         for (int i = 0; i < (int) embd_inp.size(); i++) { |         for (int i = 0; i < (int) embd_inp.size(); i++) { | ||||||
|             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); |             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); | ||||||
|         } |         } | ||||||
|  | 
 | ||||||
|  |         if (ctx_guidance) { | ||||||
|  |             fprintf(stderr, "\n"); | ||||||
|  |             fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); | ||||||
|  |             fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); | ||||||
|  |             for (int i = 0; i < (int) guidance_inp.size(); i++) { | ||||||
|  |                 fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i])); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         if (params.n_keep > 0) { |         if (params.n_keep > 0) { | ||||||
|         fprintf(stderr, "%s: static prompt based on n_keep: '", __func__); |         fprintf(stderr, "%s: static prompt based on n_keep: '", __func__); | ||||||
|             for (int i = 0; i < params.n_keep; i++) { |             for (int i = 0; i < params.n_keep; i++) { | ||||||
|  | @ -334,11 +363,13 @@ int main(int argc, char ** argv) { | ||||||
|     int n_remain           = params.n_predict; |     int n_remain           = params.n_predict; | ||||||
|     int n_consumed         = 0; |     int n_consumed         = 0; | ||||||
|     int n_session_consumed = 0; |     int n_session_consumed = 0; | ||||||
|  |     int n_past_guidance    = 0; | ||||||
| 
 | 
 | ||||||
|     // the first thing we will do is to output the prompt, so set color accordingly
 |     // the first thing we will do is to output the prompt, so set color accordingly
 | ||||||
|     console_set_color(con_st, CONSOLE_COLOR_PROMPT); |     console_set_color(con_st, CONSOLE_COLOR_PROMPT); | ||||||
| 
 | 
 | ||||||
|     std::vector<llama_token> embd; |     std::vector<llama_token> embd; | ||||||
|  |     std::vector<llama_token> embd_guidance; | ||||||
| 
 | 
 | ||||||
|     // do one empty run to warm up the model
 |     // do one empty run to warm up the model
 | ||||||
|     { |     { | ||||||
|  | @ -367,11 +398,12 @@ int main(int argc, char ** argv) { | ||||||
|             // if we run out of context:
 |             // if we run out of context:
 | ||||||
|             // - take the n_keep first tokens from the original prompt (via n_past)
 |             // - take the n_keep first tokens from the original prompt (via n_past)
 | ||||||
|             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 |             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 | ||||||
|             if (n_past + (int) embd.size() > n_ctx) { |             if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) { | ||||||
|                 const int n_left = n_past - params.n_keep; |                 const int n_left = n_past - params.n_keep; | ||||||
| 
 | 
 | ||||||
|                 // always keep the first token - BOS
 |                 // always keep the first token - BOS
 | ||||||
|                 n_past = std::max(1, params.n_keep); |                 n_past = std::max(1, params.n_keep); | ||||||
|  |                 n_past_guidance = std::max(1, params.n_keep + guidance_offset); | ||||||
| 
 | 
 | ||||||
|                 // insert n_left/2 tokens at the start of embd from last_n_tokens
 |                 // insert n_left/2 tokens at the start of embd from last_n_tokens
 | ||||||
|                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); |                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); | ||||||
|  | @ -412,6 +444,48 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|             // evaluate tokens in batches
 |             // evaluate tokens in batches
 | ||||||
|             // embd is typically prepared beforehand to fit within a batch, but not always
 |             // embd is typically prepared beforehand to fit within a batch, but not always
 | ||||||
|  | 
 | ||||||
|  |             if (ctx_guidance) { | ||||||
|  |                 int input_size = 0; | ||||||
|  |                 llama_token* input_buf = NULL; | ||||||
|  | 
 | ||||||
|  |                 if (n_past_guidance < (int) guidance_inp.size()) { | ||||||
|  |                     // Guidance context should have the same data with these modifications:
 | ||||||
|  |                     //
 | ||||||
|  |                     // * Replace the initial prompt
 | ||||||
|  |                     // * Shift everything by guidance_offset
 | ||||||
|  |                     embd_guidance = guidance_inp; | ||||||
|  |                     if (embd.begin() + original_prompt_len < embd.end()) { | ||||||
|  |                         embd_guidance.insert( | ||||||
|  |                             embd_guidance.end(), | ||||||
|  |                             embd.begin() + original_prompt_len, | ||||||
|  |                             embd.end() | ||||||
|  |                         ); | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     input_buf = embd_guidance.data(); | ||||||
|  |                     input_size = embd_guidance.size(); | ||||||
|  |                     //fprintf(stderr, "\n---------------------\n");
 | ||||||
|  |                     //for (int i = 0; i < (int) embd_guidance.size(); i++) {
 | ||||||
|  |                         //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
 | ||||||
|  |                     //}
 | ||||||
|  |                     //fprintf(stderr, "\n---------------------\n");
 | ||||||
|  |                 } else { | ||||||
|  |                     input_buf = embd.data(); | ||||||
|  |                     input_size = embd.size(); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 for (int i = 0; i < input_size; i += params.n_batch) { | ||||||
|  |                     int n_eval = std::min(input_size - i, params.n_batch); | ||||||
|  |                     if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { | ||||||
|  |                         fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|  |                         return 1; | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     n_past_guidance += n_eval; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|             for (int i = 0; i < (int) embd.size(); i += params.n_batch) { |             for (int i = 0; i < (int) embd.size(); i += params.n_batch) { | ||||||
|                 int n_eval = (int) embd.size() - i; |                 int n_eval = (int) embd.size() - i; | ||||||
|                 if (n_eval > params.n_batch) { |                 if (n_eval > params.n_batch) { | ||||||
|  | @ -431,6 +505,7 @@ int main(int argc, char ** argv) { | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         embd.clear(); |         embd.clear(); | ||||||
|  |         embd_guidance.clear(); | ||||||
| 
 | 
 | ||||||
|         if ((int) embd_inp.size() <= n_consumed && !is_interacting) { |         if ((int) embd_inp.size() <= n_consumed && !is_interacting) { | ||||||
|             // out of user input, sample next token
 |             // out of user input, sample next token
 | ||||||
|  | @ -473,6 +548,10 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; |                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; | ||||||
| 
 | 
 | ||||||
|  |                 if (ctx_guidance) { | ||||||
|  |                     llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|                 // Apply penalties
 |                 // Apply penalties
 | ||||||
|                 float nl_logit = logits[llama_token_nl()]; |                 float nl_logit = logits[llama_token_nl()]; | ||||||
|                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); |                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); | ||||||
|  | @ -668,6 +747,7 @@ int main(int argc, char ** argv) { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     llama_print_timings(ctx); |     llama_print_timings(ctx); | ||||||
|  |     if (ctx_guidance) { llama_free(ctx_guidance); } | ||||||
|     llama_free(ctx); |     llama_free(ctx); | ||||||
|     llama_free_model(model); |     llama_free_model(model); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										56
									
								
								llama.cpp
									
										
									
									
									
								
							
							
						
						
									
										56
									
								
								llama.cpp
									
										
									
									
									
								
							|  | @ -2167,6 +2167,62 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
// Replace each of the `size` values in `array` with its log-softmax, in place.
//
// Computed as x_i - max - log(sum_j exp(x_j - max)). Subtracting the maximum
// keeps expf() from overflowing, and taking the logarithm of the sum only
// (instead of logf(p_i / sum)) keeps entries far below the maximum finite
// rather than collapsing to -inf once expf() underflows to zero.
// An empty range is left untouched.
static void llama_log_softmax(float * array, size_t size) {
    if (size == 0) {
        // std::max_element would return the end iterator; dereferencing it is UB
        return;
    }

    const float max_l = *std::max_element(array, array + size);

    float sum = 0.f;
    for (size_t i = 0; i < size; ++i) {
        sum += expf(array[i] - max_l);
    }

    // sum >= 1 for finite input because the maximum element contributes expf(0)
    const float log_sum = logf(sum);
    for (size_t i = 0; i < size; ++i) {
        array[i] = array[i] - max_l - log_sum;
    }
}
|  | 
 | ||||||
|  | void llama_sample_classifier_free_guidance( | ||||||
|  |           struct llama_context * ctx, | ||||||
|  |         llama_token_data_array * candidates, | ||||||
|  |           struct llama_context * guidance_ctx, | ||||||
|  |                          float   scale, | ||||||
|  |                          float   smooth_factor) { | ||||||
|  |     int64_t t_start_sample_us = t_start_sample_us = ggml_time_us(); | ||||||
|  | 
 | ||||||
|  |     assert(ctx); | ||||||
|  |     auto n_vocab = llama_n_vocab(ctx); | ||||||
|  |     assert(n_vocab == (int)candidates->size); | ||||||
|  |     assert(!candidates->sorted); | ||||||
|  | 
 | ||||||
|  |     std::vector<float> logits_base; | ||||||
|  |     logits_base.reserve(candidates->size); | ||||||
|  |     for (size_t i = 0; i < candidates->size; ++i) { | ||||||
|  |         logits_base.push_back(candidates->data[i].logit); | ||||||
|  |     } | ||||||
|  |     llama_log_softmax(logits_base.data(), candidates->size); | ||||||
|  | 
 | ||||||
|  |     float* logits_guidance = llama_get_logits(guidance_ctx); | ||||||
|  |     llama_log_softmax(logits_guidance, n_vocab); | ||||||
|  | 
 | ||||||
|  |     for (int i = 0; i < n_vocab; ++i) { | ||||||
|  |         float logit_guidance = logits_guidance[i]; | ||||||
|  |         float logit_base = logits_base[i]; | ||||||
|  |         logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     llama_log_softmax(logits_guidance, n_vocab); | ||||||
|  | 
 | ||||||
|  |     for (int i = 0; i < n_vocab; ++i) { | ||||||
|  |         float logit_base = logits_base[i]; | ||||||
|  |         float logit_guidance = logits_guidance[i]; | ||||||
|  | 
 | ||||||
|  |         candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (ctx) { | ||||||
|  |         ctx->t_sample_us += ggml_time_us() - t_start_sample_us; | ||||||
|  |     } | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { | llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { | ||||||
|     assert(ctx); |     assert(ctx); | ||||||
|  |  | ||||||
							
								
								
									
										12
									
								
								llama.h
									
										
									
									
									
								
							
							
						
						
									
										12
									
								
								llama.h
									
										
									
									
									
								
							|  | @ -309,6 +309,18 @@ extern "C" { | ||||||
|     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
 |     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
 | ||||||
|     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); |     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); | ||||||
| 
 | 
 | ||||||
|  |     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
 | ||||||
|  |     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
 | ||||||
|  |     /// @param guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
 | ||||||
|  |     /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
 | ||||||
|  |     /// @param smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
 | ||||||
|  |     LLAMA_API void llama_sample_classifier_free_guidance( | ||||||
|  |               struct llama_context * ctx, | ||||||
|  |             llama_token_data_array * candidates, | ||||||
|  |               struct llama_context * guidance_ctx, | ||||||
|  |                              float   scale, | ||||||
|  |                              float   smooth_factor); | ||||||
|  | 
 | ||||||
|     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
 |     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
 | ||||||
|     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); |     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue