speculative : change default p_accept to 0.5 + CLI args (#3919)
ggml-ci
This commit is contained in:
		
							parent
							
								
									05816027d6
								
							
						
					
					
						commit
						8f961abdc4
					
				
					 3 changed files with 25 additions and 5 deletions
				
			
		|  | @ -403,6 +403,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { | |||
|                 break; | ||||
|             } | ||||
|             params.n_sequences = std::stoi(argv[i]); | ||||
|         } else if (arg == "--p-accept" || arg == "-pa") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.p_accept = std::stof(argv[i]); | ||||
|         } else if (arg == "--p-split" || arg == "-ps") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.p_split = std::stof(argv[i]); | ||||
|         } else if (arg == "-m" || arg == "--model") { | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|  | @ -778,6 +790,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |||
|     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); | ||||
|     printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel); | ||||
|     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences); | ||||
|     printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept); | ||||
|     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split); | ||||
|     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); | ||||
|     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); | ||||
|     printf("  --image IMAGE_FILE    path to an image file. use with multimodal models\n"); | ||||
|  |  | |||
|  | @ -44,6 +44,7 @@ int32_t get_num_physical_cores(); | |||
| 
 | ||||
| struct gpt_params { | ||||
|     uint32_t seed                           = -1;    // RNG seed
 | ||||
| 
 | ||||
|     int32_t n_threads                       = get_num_physical_cores(); | ||||
|     int32_t n_threads_batch                 = -1;    // number of threads to use for batch processing (-1 = use n_threads)
 | ||||
|     int32_t n_predict                       = -1;    // new tokens to predict
 | ||||
|  | @ -54,6 +55,8 @@ struct gpt_params { | |||
|     int32_t n_chunks                        = -1;    // max number of chunks to process (-1 = unlimited)
 | ||||
|     int32_t n_parallel                      = 1;     // number of parallel sequences to decode
 | ||||
|     int32_t n_sequences                     = 1;     // number of sequences to decode
 | ||||
|     float   p_accept                        = 0.5f;  // speculative decoding accept probability
 | ||||
|     float   p_split                         = 0.1f;  // speculative decoding split probability
 | ||||
|     int32_t n_gpu_layers                    = -1;    // number of layers to store in VRAM (-1 - use default)
 | ||||
|     int32_t n_gpu_layers_draft              = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
 | ||||
|     int32_t main_gpu                        = 0;     // the GPU that is used for scratch and small tensors
 | ||||
|  | @ -66,7 +69,8 @@ struct gpt_params { | |||
|     float   yarn_beta_fast                  = 32.0f; // YaRN low correction dim
 | ||||
|     float   yarn_beta_slow                  = 1.0f;  // YaRN high correction dim
 | ||||
|     int32_t yarn_orig_ctx                   = 0;     // YaRN original context length
 | ||||
|     int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED; | ||||
|     int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
 | ||||
|                                                                               //       pinging @cebtenzzre
 | ||||
| 
 | ||||
|     // // sampling parameters
 | ||||
|     struct llama_sampling_params sparams; | ||||
|  | @ -90,7 +94,7 @@ struct gpt_params { | |||
|     int  ppl_output_type   = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
 | ||||
|                                     //                                       (which is more convenient to use for plotting)
 | ||||
|                                     //
 | ||||
|     bool hellaswag         = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
 | ||||
|     bool   hellaswag       = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
 | ||||
|     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
 | ||||
| 
 | ||||
|     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
 | ||||
|  |  | |||
|  | @ -37,9 +37,11 @@ int main(int argc, char ** argv) { | |||
|     // max number of parallel drafting sequences (i.e. tree branches)
 | ||||
|     const int n_seq_dft = params.n_parallel; | ||||
| 
 | ||||
|     // TODO: make this configurable
 | ||||
|     const float p_accept = 0.80f; | ||||
|     const float p_split  = 0.10f; | ||||
|     // probability threshold for accepting a token from the draft model
 | ||||
|     const float p_accept = params.p_accept; | ||||
| 
 | ||||
|     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
 | ||||
|     const float p_split  = params.p_split; | ||||
| 
 | ||||
| #ifndef LOG_DISABLE_LOGS | ||||
|     log_set_target(log_filename_generator("speculative", "log")); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue