llama : remove cfg smooth factor as it is only a reparameterization of the guidance scale (#2280)
This commit is contained in:
		
							parent
							
								
									73643f5fb1
								
							
						
					
					
						commit
						ab0e26bdfb
					
				
					 5 changed files with 4 additions and 24 deletions
				
			
		|  | @ -260,12 +260,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||
|                 break; | ||||
|             } | ||||
|             params.cfg_scale = std::stof(argv[i]); | ||||
|         } else if (arg == "--cfg-smooth-factor") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.cfg_smooth_factor = std::stof(argv[i]); | ||||
|         } else if (arg == "-b" || arg == "--batch-size") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|  | @ -509,7 +503,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |||
|     fprintf(stderr, "  --cfg-negative-prompt PROMPT \n"); | ||||
|     fprintf(stderr, "                        negative prompt to use for guidance. (default: empty)\n"); | ||||
|     fprintf(stderr, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); | ||||
|     fprintf(stderr, "  --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor); | ||||
|     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx); | ||||
|     fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base); | ||||
|     fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); | ||||
|  |  | |||
|  | @ -55,7 +55,6 @@ struct gpt_params { | |||
|     // https://arxiv.org/abs/2306.17806
 | ||||
|     std::string cfg_negative_prompt;       // string to help guidance
 | ||||
|     float       cfg_scale         = 1.f;   // How strong is guidance
 | ||||
|     float       cfg_smooth_factor = 1.f;   // Smooth factor between old and new logits
 | ||||
| 
 | ||||
|     std::string model             = "models/7B/ggml-model.bin"; // model path
 | ||||
|     std::string model_alias       = "unknown"; // model alias
 | ||||
|  |  | |||
|  | @ -557,7 +557,7 @@ int main(int argc, char ** argv) { | |||
|                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; | ||||
| 
 | ||||
|                 if (ctx_guidance) { | ||||
|                     llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); | ||||
|                     llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale); | ||||
|                 } | ||||
| 
 | ||||
|                 // Apply penalties
 | ||||
|  |  | |||
							
								
								
									
										14
									
								
								llama.cpp
									
										
									
									
									
								
							
							
						
						
									
										14
									
								
								llama.cpp
									
										
									
									
									
								
							|  | @ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance( | |||
|           struct llama_context * ctx, | ||||
|         llama_token_data_array * candidates, | ||||
|           struct llama_context * guidance_ctx, | ||||
|                          float   scale, | ||||
|                          float   smooth_factor) { | ||||
|                          float   scale) { | ||||
|     int64_t t_start_sample_us = ggml_time_us(); | ||||
| 
 | ||||
|     assert(ctx); | ||||
|  | @ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance( | |||
|     for (int i = 0; i < n_vocab; ++i) { | ||||
|         float logit_guidance = logits_guidance[i]; | ||||
|         float logit_base = logits_base[i]; | ||||
|         logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance; | ||||
|     } | ||||
| 
 | ||||
|     llama_log_softmax(logits_guidance, n_vocab); | ||||
| 
 | ||||
|     for (int i = 0; i < n_vocab; ++i) { | ||||
|         float logit_base = logits_base[i]; | ||||
|         float logit_guidance = logits_guidance[i]; | ||||
| 
 | ||||
|         candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base; | ||||
|         candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; | ||||
|     } | ||||
| 
 | ||||
|     if (ctx) { | ||||
|  |  | |||
							
								
								
									
										4
									
								
								llama.h
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								llama.h
									
										
									
									
									
								
							|  | @ -344,13 +344,11 @@ extern "C" { | |||
|     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
 | ||||
|     /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
 | ||||
|     /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
 | ||||
|     /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
 | ||||
|     LLAMA_API void llama_sample_classifier_free_guidance( | ||||
|               struct llama_context * ctx, | ||||
|             llama_token_data_array * candidates, | ||||
|               struct llama_context * guidance_ctx, | ||||
|                              float   scale, | ||||
|                              float   smooth_factor); | ||||
|                              float   scale); | ||||
| 
 | ||||
|     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
 | ||||
|     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue