speculative : refactor and add a simpler example (#10362)
* speculative : refactor and add a simpler example ggml-ci
* speculative : clean-up and add comments and TODOs [no ci]
* speculative : manage context in common_speculative ggml-ci
* speculative : simplify ggml-ci
* speculative : simplify (cont) ggml-ci
* speculative : add --draft-min CLI arg
* speculative : minor fixup
* make : build fixes
* speculative : do not redraft previous drafts ggml-ci
* speculative : fix the draft sampling ggml-ci
* speculative : fix compile warning
* common : refactor args ggml-ci
* common : change defaults [no ci]
* common : final touches ggml-ci
This commit is contained in:
		
							parent
							
								
									cce5a90075
								
							
						
					
					
						commit
						d9d54e498d
					
				
					 28 changed files with 1028 additions and 326 deletions
				
			
		
							
								
								
									
										1
									
								
								Makefile
									
										
									
									
									
								
							
							
						
						
									
										1
									
								
								Makefile
									
										
									
									
									
								
							|  | @ -966,6 +966,7 @@ OBJ_COMMON = \ | |||
| 	$(DIR_COMMON)/console.o \
 | ||||
| 	$(DIR_COMMON)/ngram-cache.o \
 | ||||
| 	$(DIR_COMMON)/sampling.o \
 | ||||
| 	$(DIR_COMMON)/speculative.o \
 | ||||
| 	$(DIR_COMMON)/build-info.o \
 | ||||
| 	$(DIR_COMMON)/json-schema-to-grammar.o | ||||
| 
 | ||||
|  |  | |||
|  | @ -66,6 +66,8 @@ add_library(${TARGET} STATIC | |||
|     ngram-cache.h | ||||
|     sampling.cpp | ||||
|     sampling.h | ||||
|     speculative.cpp | ||||
|     speculative.h | ||||
|     ) | ||||
| 
 | ||||
| if (BUILD_SHARED_LIBS) | ||||
|  |  | |||
							
								
								
									
										436
									
								
								common/arg.cpp
									
										
									
									
									
								
							
							
						
						
									
										436
									
								
								common/arg.cpp
									
										
									
									
									
								
							|  | @ -235,8 +235,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context | |||
| 
 | ||||
|     postprocess_cpu_params(params.cpuparams,       nullptr); | ||||
|     postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams); | ||||
|     postprocess_cpu_params(params.draft_cpuparams, &params.cpuparams); | ||||
|     postprocess_cpu_params(params.draft_cpuparams_batch, &params.cpuparams_batch); | ||||
| 
 | ||||
|     postprocess_cpu_params(params.speculative.cpuparams,       &params.cpuparams); | ||||
|     postprocess_cpu_params(params.speculative.cpuparams_batch, &params.cpuparams_batch); | ||||
| 
 | ||||
|     if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { | ||||
|         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); | ||||
|  | @ -251,7 +252,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context | |||
|         for (auto & antiprompt : params.antiprompt) { | ||||
|             string_process_escapes(antiprompt); | ||||
|         } | ||||
|         for (auto & seq_breaker : params.sparams.dry_sequence_breakers) { | ||||
|         for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { | ||||
|             string_process_escapes(seq_breaker); | ||||
|         } | ||||
|     } | ||||
|  | @ -329,7 +330,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
| 
 | ||||
|     std::string sampler_type_chars; | ||||
|     std::string sampler_type_names; | ||||
|     for (const auto & sampler : params.sparams.samplers) { | ||||
|     for (const auto & sampler : params.sampling.samplers) { | ||||
|         sampler_type_chars += common_sampler_type_to_chr(sampler); | ||||
|         sampler_type_names += common_sampler_type_to_str(sampler) + ";"; | ||||
|     } | ||||
|  | @ -407,26 +408,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             } | ||||
|         } | ||||
|     )); | ||||
|     add_opt(common_arg( | ||||
|         {"-td", "--threads-draft"}, "N", | ||||
|         "number of threads to use during generation (default: same as --threads)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.draft_cpuparams.n_threads = value; | ||||
|             if (params.draft_cpuparams.n_threads <= 0) { | ||||
|                 params.draft_cpuparams.n_threads = std::thread::hardware_concurrency(); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-tbd", "--threads-batch-draft"}, "N", | ||||
|         "number of threads to use during batch and prompt processing (default: same as --threads-draft)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.draft_cpuparams_batch.n_threads = value; | ||||
|             if (params.draft_cpuparams_batch.n_threads <= 0) { | ||||
|                 params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency(); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-C", "--cpu-mask"}, "M", | ||||
|         "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")", | ||||
|  | @ -515,108 +496,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             params.cpuparams_batch.poll = value; | ||||
|         } | ||||
|     )); | ||||
|     add_opt(common_arg( | ||||
|         {"-Cd", "--cpu-mask-draft"}, "M", | ||||
|         "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", | ||||
|         [](common_params & params, const std::string & mask) { | ||||
|             params.draft_cpuparams.mask_valid = true; | ||||
|             if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid cpumask"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Crd", "--cpu-range-draft"}, "lo-hi", | ||||
|         "Ranges of CPUs for affinity. Complements --cpu-mask-draft", | ||||
|         [](common_params & params, const std::string & range) { | ||||
|             params.draft_cpuparams.mask_valid = true; | ||||
|             if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid range"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--cpu-strict-draft"}, "<0|1>", | ||||
|         "Use strict CPU placement for draft model (default: same as --cpu-strict)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.draft_cpuparams.strict_cpu = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--prio-draft"}, "N", | ||||
|         string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams.priority), | ||||
|         [](common_params & params, int prio) { | ||||
|             if (prio < 0 || prio > 3) { | ||||
|                 throw std::invalid_argument("invalid value"); | ||||
|             } | ||||
|             params.draft_cpuparams.priority = (enum ggml_sched_priority) prio; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--poll-draft"}, "<0|1>", | ||||
|         "Use polling to wait for draft model work (default: same as --poll])", | ||||
|         [](common_params & params, int value) { | ||||
|             params.draft_cpuparams.poll = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Cbd", "--cpu-mask-batch-draft"}, "M", | ||||
|         "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", | ||||
|         [](common_params & params, const std::string & mask) { | ||||
|             params.draft_cpuparams_batch.mask_valid = true; | ||||
|             if (!parse_cpu_mask(mask, params.draft_cpuparams_batch.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid cpumask"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", | ||||
|         "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)", | ||||
|         [](common_params & params, const std::string & range) { | ||||
|             params.draft_cpuparams_batch.mask_valid = true; | ||||
|             if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid cpumask"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--cpu-strict-batch-draft"}, "<0|1>", | ||||
|         "Use strict CPU placement for draft model (default: --cpu-strict-draft)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.draft_cpuparams_batch.strict_cpu = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--prio-batch-draft"}, "N", | ||||
|         string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams_batch.priority), | ||||
|         [](common_params & params, int prio) { | ||||
|             if (prio < 0 || prio > 3) { | ||||
|                 throw std::invalid_argument("invalid value"); | ||||
|             } | ||||
|             params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) prio; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--poll-batch-draft"}, "<0|1>", | ||||
|         "Use polling to wait for draft model work (default: --poll-draft)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.draft_cpuparams_batch.poll = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--draft"}, "N", | ||||
|         string_format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft), | ||||
|         [](common_params & params, int value) { | ||||
|             params.n_draft = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP})); | ||||
|     add_opt(common_arg( | ||||
|         {"-ps", "--p-split"}, "N", | ||||
|         string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.p_split = std::stof(value); | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-lcs", "--lookup-cache-static"}, "FNAME", | ||||
|         "path to static lookup cache to use for lookup decoding (not updated by generation)", | ||||
|  | @ -701,7 +580,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|         string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"), | ||||
|         [](common_params & params) { | ||||
|             params.no_perf = true; | ||||
|             params.sparams.no_perf = true; | ||||
|             params.sampling.no_perf = true; | ||||
|         } | ||||
|     ).set_env("LLAMA_ARG_NO_PERF")); | ||||
|     add_opt(common_arg( | ||||
|  | @ -883,155 +762,155 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             const auto sampler_names = string_split<std::string>(value, ';'); | ||||
|             params.sparams.samplers = common_sampler_types_from_names(sampler_names, true); | ||||
|             params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"-s", "--seed"}, "SEED", | ||||
|         string_format("RNG seed (default: %d, use random seed for %d)", params.sparams.seed, LLAMA_DEFAULT_SEED), | ||||
|         string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.seed = std::stoul(value); | ||||
|             params.sampling.seed = std::stoul(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--sampling-seq"}, "SEQUENCE", | ||||
|         string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.samplers = common_sampler_types_from_chars(value); | ||||
|             params.sampling.samplers = common_sampler_types_from_chars(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--ignore-eos"}, | ||||
|         "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)", | ||||
|         [](common_params & params) { | ||||
|             params.sparams.ignore_eos = true; | ||||
|             params.sampling.ignore_eos = true; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--penalize-nl"}, | ||||
|         string_format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"), | ||||
|         string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"), | ||||
|         [](common_params & params) { | ||||
|             params.sparams.penalize_nl = true; | ||||
|             params.sampling.penalize_nl = true; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--temp"}, "N", | ||||
|         string_format("temperature (default: %.1f)", (double)params.sparams.temp), | ||||
|         string_format("temperature (default: %.1f)", (double)params.sampling.temp), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.temp = std::stof(value); | ||||
|             params.sparams.temp = std::max(params.sparams.temp, 0.0f); | ||||
|             params.sampling.temp = std::stof(value); | ||||
|             params.sampling.temp = std::max(params.sampling.temp, 0.0f); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--top-k"}, "N", | ||||
|         string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k), | ||||
|         string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), | ||||
|         [](common_params & params, int value) { | ||||
|             params.sparams.top_k = value; | ||||
|             params.sampling.top_k = value; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--top-p"}, "N", | ||||
|         string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p), | ||||
|         string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.top_p = std::stof(value); | ||||
|             params.sampling.top_p = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--min-p"}, "N", | ||||
|         string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p), | ||||
|         string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.min_p = std::stof(value); | ||||
|             params.sampling.min_p = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--xtc-probability"}, "N", | ||||
|         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability), | ||||
|         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.xtc_probability = std::stof(value); | ||||
|             params.sampling.xtc_probability = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--xtc-threshold"}, "N", | ||||
|         string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold), | ||||
|         string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.xtc_threshold = std::stof(value); | ||||
|             params.sampling.xtc_threshold = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--typical"}, "N", | ||||
|         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p), | ||||
|         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.typ_p = std::stof(value); | ||||
|             params.sampling.typ_p = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--repeat-last-n"}, "N", | ||||
|         string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n), | ||||
|         string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n), | ||||
|         [](common_params & params, int value) { | ||||
|             params.sparams.penalty_last_n = value; | ||||
|             params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n); | ||||
|             params.sampling.penalty_last_n = value; | ||||
|             params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--repeat-penalty"}, "N", | ||||
|         string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat), | ||||
|         string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.penalty_repeat = std::stof(value); | ||||
|             params.sampling.penalty_repeat = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--presence-penalty"}, "N", | ||||
|         string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present), | ||||
|         string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.penalty_present = std::stof(value); | ||||
|             params.sampling.penalty_present = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--frequency-penalty"}, "N", | ||||
|         string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq), | ||||
|         string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.penalty_freq = std::stof(value); | ||||
|             params.sampling.penalty_freq = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dry-multiplier"}, "N", | ||||
|         string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier), | ||||
|         string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.dry_multiplier = std::stof(value); | ||||
|             params.sampling.dry_multiplier = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dry-base"}, "N", | ||||
|         string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base), | ||||
|         string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             float potential_base = std::stof(value); | ||||
|             if (potential_base >= 1.0f) | ||||
|             { | ||||
|                 params.sparams.dry_base = potential_base; | ||||
|                 params.sampling.dry_base = potential_base; | ||||
|             } | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dry-allowed-length"}, "N", | ||||
|         string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length), | ||||
|         string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length), | ||||
|         [](common_params & params, int value) { | ||||
|             params.sparams.dry_allowed_length = value; | ||||
|             params.sampling.dry_allowed_length = value; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dry-penalty-last-n"}, "N", | ||||
|         string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n), | ||||
|         string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n), | ||||
|         [](common_params & params, int value) { | ||||
|             params.sparams.dry_penalty_last_n = value; | ||||
|             params.sampling.dry_penalty_last_n = value; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dry-sequence-breaker"}, "STRING", | ||||
|         string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n", | ||||
|             params.sparams.dry_sequence_breakers.empty() ? "none" : | ||||
|             std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()), | ||||
|                 params.sparams.dry_sequence_breakers.end(), | ||||
|                 std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'", | ||||
|             params.sampling.dry_sequence_breakers.empty() ? "none" : | ||||
|             std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()), | ||||
|                 params.sampling.dry_sequence_breakers.end(), | ||||
|                 std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'", | ||||
|                 [](const std::string& a, const std::string& b) { | ||||
|                     std::string formatted_b = (b == "\n") ? "\\n" : b; | ||||
|                     return a + ", '" + formatted_b + "'"; | ||||
|  | @ -1040,51 +919,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             static bool defaults_cleared = false; | ||||
| 
 | ||||
|             if (!defaults_cleared) { | ||||
|                 params.sparams.dry_sequence_breakers.clear(); | ||||
|                 params.sampling.dry_sequence_breakers.clear(); | ||||
|                 defaults_cleared = true; | ||||
|             } | ||||
| 
 | ||||
|             if (value == "none") { | ||||
|                 params.sparams.dry_sequence_breakers.clear(); | ||||
|                 params.sampling.dry_sequence_breakers.clear(); | ||||
|             } else { | ||||
|                 params.sparams.dry_sequence_breakers.emplace_back(value); | ||||
|                 params.sampling.dry_sequence_breakers.emplace_back(value); | ||||
|             } | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dynatemp-range"}, "N", | ||||
|         string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range), | ||||
|         string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.dynatemp_range = std::stof(value); | ||||
|             params.sampling.dynatemp_range = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--dynatemp-exp"}, "N", | ||||
|         string_format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent), | ||||
|         string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.dynatemp_exponent = std::stof(value); | ||||
|             params.sampling.dynatemp_exponent = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--mirostat"}, "N", | ||||
|         string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n" | ||||
|         "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat), | ||||
|         "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), | ||||
|         [](common_params & params, int value) { | ||||
|             params.sparams.mirostat = value; | ||||
|             params.sampling.mirostat = value; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--mirostat-lr"}, "N", | ||||
|         string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta), | ||||
|         string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.mirostat_eta = std::stof(value); | ||||
|             params.sampling.mirostat_eta = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--mirostat-ent"}, "N", | ||||
|         string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau), | ||||
|         string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.mirostat_tau = std::stof(value); | ||||
|             params.sampling.mirostat_tau = std::stof(value); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|  | @ -1100,7 +979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             try { | ||||
|                 if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { | ||||
|                     const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); | ||||
|                     params.sparams.logit_bias.push_back({key, bias}); | ||||
|                     params.sampling.logit_bias.push_back({key, bias}); | ||||
|                 } else { | ||||
|                     throw std::invalid_argument("invalid input format"); | ||||
|                 } | ||||
|  | @ -1111,9 +990,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|         {"--grammar"}, "GRAMMAR", | ||||
|         string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()), | ||||
|         string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.grammar = value; | ||||
|             params.sampling.grammar = value; | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|  | @ -1127,7 +1006,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             std::copy( | ||||
|                 std::istreambuf_iterator<char>(file), | ||||
|                 std::istreambuf_iterator<char>(), | ||||
|                 std::back_inserter(params.sparams.grammar) | ||||
|                 std::back_inserter(params.sampling.grammar) | ||||
|             ); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|  | @ -1135,7 +1014,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|         {"-j", "--json-schema"}, "SCHEMA", | ||||
|         "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.sparams.grammar = json_schema_to_grammar(json::parse(value)); | ||||
|             params.sampling.grammar = json_schema_to_grammar(json::parse(value)); | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(common_arg( | ||||
|  | @ -1444,17 +1323,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             } | ||||
|         } | ||||
|     ).set_env("LLAMA_ARG_N_GPU_LAYERS")); | ||||
|     add_opt(common_arg( | ||||
|         {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N", | ||||
|         "number of layers to store in VRAM for the draft model", | ||||
|         [](common_params & params, int value) { | ||||
|             params.n_gpu_layers_draft = value; | ||||
|             if (!llama_supports_gpu_offload()) { | ||||
|                 fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); | ||||
|                 fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-sm", "--split-mode"}, "{none,layer,row}", | ||||
|         "how to split the model across multiple GPUs, one of:\n" | ||||
|  | @ -1593,13 +1461,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|             params.model = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); | ||||
|     add_opt(common_arg( | ||||
|         {"-md", "--model-draft"}, "FNAME", | ||||
|         "draft model for speculative decoding (default: unused)", | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.model_draft = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-mu", "--model-url"}, "MODEL_URL", | ||||
|         "model download url (default: unused)", | ||||
|  | @ -2037,5 +1898,168 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||
|         } | ||||
|     ).set_env("LLAMA_LOG_TIMESTAMPS")); | ||||
| 
 | ||||
|     // speculative parameters
 | ||||
|     add_opt(common_arg( | ||||
|         {"-td", "--threads-draft"}, "N", | ||||
|         "number of threads to use during generation (default: same as --threads)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.cpuparams.n_threads = value; | ||||
|             if (params.speculative.cpuparams.n_threads <= 0) { | ||||
|                 params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency(); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-tbd", "--threads-batch-draft"}, "N", | ||||
|         "number of threads to use during batch and prompt processing (default: same as --threads-draft)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.cpuparams_batch.n_threads = value; | ||||
|             if (params.speculative.cpuparams_batch.n_threads <= 0) { | ||||
|                 params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Cd", "--cpu-mask-draft"}, "M", | ||||
|         "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", | ||||
|         [](common_params & params, const std::string & mask) { | ||||
|             params.speculative.cpuparams.mask_valid = true; | ||||
|             if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid cpumask"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Crd", "--cpu-range-draft"}, "lo-hi", | ||||
|         "Ranges of CPUs for affinity. Complements --cpu-mask-draft", | ||||
|         [](common_params & params, const std::string & range) { | ||||
|             params.speculative.cpuparams.mask_valid = true; | ||||
|             if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid range"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--cpu-strict-draft"}, "<0|1>", | ||||
|         "Use strict CPU placement for draft model (default: same as --cpu-strict)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.cpuparams.strict_cpu = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--prio-draft"}, "N", | ||||
|         string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority), | ||||
|         [](common_params & params, int prio) { | ||||
|             if (prio < 0 || prio > 3) { | ||||
|                 throw std::invalid_argument("invalid value"); | ||||
|             } | ||||
|             params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--poll-draft"}, "<0|1>", | ||||
|         "Use polling to wait for draft model work (default: same as --poll])", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.cpuparams.poll = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Cbd", "--cpu-mask-batch-draft"}, "M", | ||||
|         "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", | ||||
|         [](common_params & params, const std::string & mask) { | ||||
|             params.speculative.cpuparams_batch.mask_valid = true; | ||||
|             if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid cpumask"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", | ||||
|         "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)", | ||||
|         [](common_params & params, const std::string & range) { | ||||
|             params.speculative.cpuparams_batch.mask_valid = true; | ||||
|             if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) { | ||||
|                 throw std::invalid_argument("invalid cpumask"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--cpu-strict-batch-draft"}, "<0|1>", | ||||
|         "Use strict CPU placement for draft model (default: --cpu-strict-draft)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.cpuparams_batch.strict_cpu = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--prio-batch-draft"}, "N", | ||||
|         string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority), | ||||
|         [](common_params & params, int prio) { | ||||
|             if (prio < 0 || prio > 3) { | ||||
|                 throw std::invalid_argument("invalid value"); | ||||
|             } | ||||
|             params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--poll-batch-draft"}, "<0|1>", | ||||
|         "Use polling to wait for draft model work (default: --poll-draft)", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.cpuparams_batch.poll = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--draft-max", "--draft", "--draft-n"}, "N", | ||||
|         string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max), | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.n_max = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"--draft-min", "--draft-n-min"}, "N", | ||||
|         string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min), | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.n_min = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"--draft-p-split"}, "P", | ||||
|         string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.speculative.p_split = std::stof(value); | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); | ||||
|     add_opt(common_arg( | ||||
|         {"--draft-p-min"}, "P", | ||||
|         string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min), | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.speculative.p_min = std::stof(value); | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"-cd", "--ctx-size-draft"}, "N", | ||||
|         string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.n_ctx = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N", | ||||
|         "number of layers to store in VRAM for the draft model", | ||||
|         [](common_params & params, int value) { | ||||
|             params.speculative.n_gpu_layers = value; | ||||
|             if (!llama_supports_gpu_offload()) { | ||||
|                 fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); | ||||
|                 fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); | ||||
|             } | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); | ||||
|     add_opt(common_arg( | ||||
|         {"-md", "--model-draft"}, "FNAME", | ||||
|         "draft model for speculative decoding (default: unused)", | ||||
|         [](common_params & params, const std::string & value) { | ||||
|             params.speculative.model = value; | ||||
|         } | ||||
|     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); | ||||
| 
 | ||||
|     return ctx_arg; | ||||
| } | ||||
|  |  | |||
|  | @ -537,11 +537,11 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat | |||
|                 detokenized.end()); | ||||
| 
 | ||||
|         buf << "\n"          << std::to_string(i) | ||||
|             << ":token '" << detokenized << "'" | ||||
|             << ":pos " << std::to_string(batch.pos[i]) | ||||
|             << ":n_seq_id  " << std::to_string(batch.n_seq_id[i]) | ||||
|             << ":seq_id " << std::to_string(batch.seq_id[i][0]) | ||||
|             << ":logits " << std::to_string(batch.logits[i]); | ||||
|             << ", token '"   << detokenized << "'" | ||||
|             << ", pos "      << std::to_string(batch.pos[i]) | ||||
|             << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) | ||||
|             << ", seq_id "   << std::to_string(batch.seq_id[i][0]) | ||||
|             << ", logits "   << std::to_string(batch.logits[i]); | ||||
|     } | ||||
| 
 | ||||
|     buf << " ]"; | ||||
|  | @ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) { | |||
|         common_lora_adapters_apply(lctx, iparams.lora_adapters); | ||||
|     } | ||||
| 
 | ||||
|     if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { | ||||
|     if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { | ||||
|         LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); | ||||
|         params.sparams.ignore_eos = false; | ||||
|         params.sampling.ignore_eos = false; | ||||
|     } | ||||
| 
 | ||||
|     if (params.warmup) { | ||||
|  | @ -1490,6 +1490,66 @@ void common_batch_add( | |||
|     batch.n_tokens++; | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // Token utils
 | ||||
| //
 | ||||
| 
 | ||||
| size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { | ||||
|     size_t i; | ||||
|     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} | ||||
| 
 | ||||
|     return i; | ||||
| } | ||||
| 
 | ||||
| size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { | ||||
|     // check for empty sequences
 | ||||
|     if (a.empty() || b.empty()) { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     // get the lengths of the input sequences
 | ||||
|     size_t a_len = a.size(); | ||||
|     size_t b_len = b.size(); | ||||
| 
 | ||||
|     // initialize the maximum length of the longest common subsequence (LCS)
 | ||||
|     size_t max_length = 0; | ||||
| 
 | ||||
|     // use two rows instead of a 2D matrix to optimize space
 | ||||
|     std::vector<size_t> prev_row(b_len + 1, 0); | ||||
|     std::vector<size_t> curr_row(b_len + 1, 0); | ||||
| 
 | ||||
|     // iterate through the elements of a
 | ||||
|     for (size_t i = 1; i <= a_len; i++) { | ||||
|         // iterate through the elements of b
 | ||||
|         for (size_t j = 1; j <= b_len; j++) { | ||||
|             // if elements at the current positions match
 | ||||
|             if (a[i - 1] == b[j - 1]) { | ||||
|                 // if it's the first element of either sequences, set LCS length to 1
 | ||||
|                 if (i == 1 || j == 1) { | ||||
|                     curr_row[j] = 1; | ||||
|                 } else { | ||||
|                     // increment LCS length by 1 compared to the previous element
 | ||||
|                     curr_row[j] = prev_row[j - 1] + 1; | ||||
|                 } | ||||
| 
 | ||||
|                 // update max_length if necessary
 | ||||
|                 if (curr_row[j] > max_length) { | ||||
|                     max_length = curr_row[j]; | ||||
|                 } | ||||
|             } else { | ||||
|                 // reset LCS length if elements don't match
 | ||||
|                 curr_row[j] = 0; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // update the previous row for the next iteration
 | ||||
|         prev_row = curr_row; | ||||
|     } | ||||
| 
 | ||||
|     // return the maximum length of the LCS
 | ||||
|     return max_length; | ||||
| } | ||||
| 
 | ||||
| //
 | ||||
| // Vocab utils
 | ||||
| //
 | ||||
|  |  | |||
|  | @ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info { | |||
|     struct llama_lora_adapter * adapter; | ||||
| }; | ||||
| 
 | ||||
| using llama_tokens = std::vector<llama_token>; | ||||
| 
 | ||||
| // build info
 | ||||
| extern int LLAMA_BUILD_NUMBER; | ||||
| extern char const * LLAMA_COMMIT; | ||||
|  | @ -101,8 +103,8 @@ enum dimre_method { | |||
|     DIMRE_METHOD_MEAN, | ||||
| }; | ||||
| 
 | ||||
| // sampler parameters
 | ||||
| struct common_sampler_params { | ||||
| // sampling parameters
 | ||||
| struct common_params_sampling { | ||||
|     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
 | ||||
| 
 | ||||
|     int32_t n_prev             = 64;    // number of previous tokens to remember
 | ||||
|  | @ -153,19 +155,30 @@ struct common_sampler_params { | |||
|     std::string print() const; | ||||
| }; | ||||
| 
 | ||||
| struct common_params_speculative { | ||||
|     int32_t n_ctx        =     0; // draft context size
 | ||||
|     int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
 | ||||
|     int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
 | ||||
|     int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
 | ||||
|     float   p_split      =  0.1f; // speculative decoding split probability
 | ||||
|     float   p_min        =  0.9f; // minimum speculative decoding probability (greedy)
 | ||||
| 
 | ||||
|     struct cpu_params cpuparams; | ||||
|     struct cpu_params cpuparams_batch; | ||||
| 
 | ||||
|     std::string model = ""; // draft model for speculative decoding                          // NOLINT
 | ||||
| }; | ||||
| 
 | ||||
| struct common_params { | ||||
|     int32_t n_predict             =    -1; // new tokens to predict
 | ||||
|     int32_t n_ctx                 =  4096; // context size
 | ||||
|     int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
 | ||||
|     int32_t n_ubatch              =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
 | ||||
|     int32_t n_keep                =     0; // number of tokens to keep from initial prompt
 | ||||
|     int32_t n_draft               =     5; // number of tokens to draft during speculative decoding
 | ||||
|     int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
 | ||||
|     int32_t n_parallel            =     1; // number of parallel sequences to decode
 | ||||
|     int32_t n_sequences           =     1; // number of sequences to decode
 | ||||
|     float   p_split               =  0.1f; // speculative decoding split probability
 | ||||
|     int32_t n_gpu_layers          =    -1; // number of layers to store in VRAM (-1 - use default)
 | ||||
|     int32_t n_gpu_layers_draft    =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
 | ||||
|     int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
 | ||||
|     float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
 | ||||
|     int32_t grp_attn_n            =     1; // group-attention factor
 | ||||
|  | @ -182,8 +195,6 @@ struct common_params { | |||
| 
 | ||||
|     struct cpu_params cpuparams; | ||||
|     struct cpu_params cpuparams_batch; | ||||
|     struct cpu_params draft_cpuparams; | ||||
|     struct cpu_params draft_cpuparams_batch; | ||||
| 
 | ||||
|     ggml_backend_sched_eval_callback cb_eval = nullptr; | ||||
|     void * cb_eval_user_data                 = nullptr; | ||||
|  | @ -195,10 +206,10 @@ struct common_params { | |||
|     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
 | ||||
|     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 | ||||
| 
 | ||||
|     struct common_sampler_params sparams; | ||||
|     struct common_params_sampling sampling; | ||||
|     struct common_params_speculative speculative; | ||||
| 
 | ||||
|     std::string model                = ""; // model path                                                    // NOLINT
 | ||||
|     std::string model_draft          = ""; // draft model for speculative decoding                          // NOLINT
 | ||||
|     std::string model_alias          = "unknown"; // model alias                                            // NOLINT
 | ||||
|     std::string model_url            = ""; // model url to download                                         // NOLINT
 | ||||
|     std::string hf_token             = ""; // HF token                                                      // NOLINT
 | ||||
|  | @ -461,7 +472,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f | |||
| // clear LoRA adapters from context, then apply new list of adapters
 | ||||
| void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters); | ||||
| 
 | ||||
| //
 | ||||
| // Batch utils
 | ||||
| //
 | ||||
| 
 | ||||
| void common_batch_clear(struct llama_batch & batch); | ||||
| 
 | ||||
|  | @ -472,6 +485,16 @@ void common_batch_add( | |||
|     const std::vector<llama_seq_id> & seq_ids, | ||||
|                                bool   logits); | ||||
| 
 | ||||
| //
 | ||||
| // Token utils
 | ||||
| //
 | ||||
| 
 | ||||
| // longest common prefix
 | ||||
| size_t common_lcp(const llama_tokens & a, const llama_tokens & b); | ||||
| 
 | ||||
| // longest common substring (length of the longest contiguous run of tokens shared by a and b)
 | ||||
| size_t common_lcs(const llama_tokens & a, const llama_tokens & b); | ||||
| 
 | ||||
| //
 | ||||
| // Vocab utils
 | ||||
| //
 | ||||
|  |  | |||
|  | @ -99,7 +99,7 @@ struct ring_buffer { | |||
| }; | ||||
| 
 | ||||
| struct common_sampler { | ||||
|     common_sampler_params params; | ||||
|     common_params_sampling params; | ||||
| 
 | ||||
|     struct llama_sampler * grmr; | ||||
|     struct llama_sampler * chain; | ||||
|  | @ -125,7 +125,7 @@ struct common_sampler { | |||
|     } | ||||
| }; | ||||
| 
 | ||||
| std::string common_sampler_params::print() const { | ||||
| std::string common_params_sampling::print() const { | ||||
|     char result[1024]; | ||||
| 
 | ||||
|     snprintf(result, sizeof(result), | ||||
|  | @ -141,7 +141,7 @@ std::string common_sampler_params::print() const { | |||
|     return std::string(result); | ||||
| } | ||||
| 
 | ||||
| struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) { | ||||
| struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { | ||||
|     llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); | ||||
| 
 | ||||
|     lparams.no_perf = params.no_perf; | ||||
|  | @ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co | |||
|     return cur_p.data[cur_p.selected].id; | ||||
| } | ||||
| 
 | ||||
| std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) { | ||||
|     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); | ||||
| 
 | ||||
|     std::vector<llama_token> result; | ||||
|     result.reserve(idxs.size()); | ||||
| 
 | ||||
|     size_t i = 0; | ||||
|     for (; i < draft.size(); i++) { | ||||
|         const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); | ||||
| 
 | ||||
|         common_sampler_accept(gsmpl, id, true); | ||||
| 
 | ||||
|         result.push_back(id); | ||||
| 
 | ||||
|         if (draft[i] != id) { | ||||
|             break; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (i == draft.size()) { | ||||
|         const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); | ||||
| 
 | ||||
|         common_sampler_accept(gsmpl, id, true); | ||||
| 
 | ||||
|         result.push_back(id); | ||||
|     } | ||||
| 
 | ||||
|     return result; | ||||
| } | ||||
| 
 | ||||
| std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { | ||||
|     std::vector<int> idxs(draft.size() + 1); | ||||
|     for (size_t i = 0; i < idxs.size(); ++i) { | ||||
|         idxs[i] = i; | ||||
|     } | ||||
| 
 | ||||
|     return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); | ||||
| } | ||||
| 
 | ||||
| uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { | ||||
|     return llama_sampler_get_seed(gsmpl->chain); | ||||
| } | ||||
|  |  | |||
|  | @ -36,7 +36,7 @@ struct common_sampler; | |||
| 
 | ||||
| // llama_sampler API overloads
 | ||||
| 
 | ||||
| struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params); | ||||
| struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); | ||||
| 
 | ||||
| void common_sampler_free(struct common_sampler * gsmpl); | ||||
| 
 | ||||
|  | @ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam | |||
| //
 | ||||
| llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); | ||||
| 
 | ||||
| // generalized version of common_sampler_sample
 | ||||
| //
 | ||||
| // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
 | ||||
| // if the sampler disagrees at some point, we stop and return the accepted tokens up to now
 | ||||
| //
 | ||||
| //      common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
 | ||||
| //
 | ||||
| // is equivalent to
 | ||||
| //
 | ||||
| //      common_sampler_sample(gsmpl, ctx, idx);
 | ||||
| //      common_sampler_accept(gsmpl, token, true);
 | ||||
| //
 | ||||
| // requires: idxs.size() == draft.size() + 1
 | ||||
| //
 | ||||
| // returns at least 1 token, up to idxs.size()
 | ||||
| //
 | ||||
| std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false); | ||||
| 
 | ||||
| // assume idxs == [ 0, 1, 2, ..., draft.size() ]
 | ||||
| std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); | ||||
| 
 | ||||
| uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); | ||||
| 
 | ||||
| // helpers
 | ||||
|  |  | |||
							
								
								
									
										269
									
								
								common/speculative.cpp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										269
									
								
								common/speculative.cpp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,269 @@ | |||
| #include "speculative.h" | ||||
| 
 | ||||
| #include "log.h" | ||||
| #include "common.h" | ||||
| #include "sampling.h" | ||||
| 
 | ||||
| #include <cstring> | ||||
| 
 | ||||
| #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128 | ||||
| #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 | ||||
| 
 | ||||
| struct common_speculative { | ||||
|     struct llama_context * ctx; | ||||
|     struct common_sampler * smpl; | ||||
| 
 | ||||
|     llama_batch batch; | ||||
|     llama_tokens prompt; | ||||
| }; | ||||
| 
 | ||||
| struct common_speculative * common_speculative_init( | ||||
|         struct llama_context * ctx_dft) { | ||||
|     auto * result = new common_speculative { | ||||
|         /* .ctx    = */ ctx_dft, | ||||
|         /* .smpl   = */ nullptr, | ||||
|         /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), | ||||
|         /* .prompt = */ {}, | ||||
|     }; | ||||
| 
 | ||||
|     // TODO: optimize or pass from outside?
 | ||||
| #if 0 | ||||
|     { | ||||
|         common_params_sampling params; | ||||
|         params.no_perf = false; | ||||
| 
 | ||||
|         params.top_k = 40; | ||||
|         params.top_p = 0.9; | ||||
| 
 | ||||
|         params.samplers = { | ||||
|             COMMON_SAMPLER_TYPE_TOP_K, | ||||
|             COMMON_SAMPLER_TYPE_TOP_P, | ||||
|             COMMON_SAMPLER_TYPE_INFILL, | ||||
|         }; | ||||
| 
 | ||||
|         result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); | ||||
|     } | ||||
| #else | ||||
|     { | ||||
|         common_params_sampling params; | ||||
|         params.no_perf = false; | ||||
| 
 | ||||
|         params.top_k = 10; | ||||
| 
 | ||||
|         params.samplers = { | ||||
|             COMMON_SAMPLER_TYPE_TOP_K, | ||||
|         }; | ||||
| 
 | ||||
|         result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); | ||||
|     } | ||||
| #endif | ||||
| 
 | ||||
|     return result; | ||||
| } | ||||
| 
 | ||||
| void common_speculative_free(struct common_speculative * spec) { | ||||
|     common_sampler_free(spec->smpl); | ||||
| 
 | ||||
|     llama_batch_free(spec->batch); | ||||
| 
 | ||||
|     delete spec; | ||||
| } | ||||
| 
 | ||||
| bool common_speculative_are_compatible( | ||||
|         const struct llama_context * ctx_tgt, | ||||
|         const struct llama_context * ctx_dft) { | ||||
|     const struct llama_model * model_tgt = llama_get_model(ctx_tgt); | ||||
|     const struct llama_model * model_dft = llama_get_model(ctx_dft); | ||||
| 
 | ||||
|     const bool vocab_type_tgt = llama_vocab_type(model_tgt); | ||||
|     LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); | ||||
| 
 | ||||
|     const bool vocab_type_dft = llama_vocab_type(model_dft); | ||||
|     LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); | ||||
| 
 | ||||
|     if (vocab_type_tgt != vocab_type_dft) { | ||||
|         LOG_ERR("%s: draft model vocab type must match target model to use speculation but " | ||||
|                      "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || | ||||
|         llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || | ||||
|         llama_token_bos(model_tgt) != llama_token_bos(model_dft) || | ||||
|         llama_token_eos(model_tgt) != llama_token_eos(model_dft) | ||||
|     ) { | ||||
|         LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     { | ||||
|         const int n_vocab_tgt = llama_n_vocab(model_tgt); | ||||
|         const int n_vocab_dft = llama_n_vocab(model_dft); | ||||
| 
 | ||||
|         const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); | ||||
| 
 | ||||
|         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { | ||||
|             LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " | ||||
|                          "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", | ||||
|                     __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { | ||||
|             const char * token_text_tgt = llama_token_get_text(model_tgt, i); | ||||
|             const char * token_text_dft = llama_token_get_text(model_dft, i); | ||||
|             if (std::strcmp(token_text_tgt, token_text_dft) != 0) { | ||||
|                 LOG_ERR("%s: draft model vocab must match target model to use speculation but " | ||||
|                              "token %d content differs - target '%s', draft '%s'\n", __func__, i, | ||||
|                         common_token_to_piece(ctx_tgt, i).c_str(), | ||||
|                         common_token_to_piece(ctx_dft, i).c_str()); | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return true; | ||||
| } | ||||
| 
 | ||||
// Draft a continuation of the target prompt using the draft context stored in `spec`.
//
//   spec       - speculator state; this function works on its batch, ctx, smpl and prompt members.
//                `prompt` is the record of tokens this function has submitted to the draft context.
//   params     - n_draft: upper bound on the number of drafted tokens
//                n_reuse: min length of a matching run for part of the old draft prompt to be reused
//                         (not required when the whole target prompt fits in n_ctx, see below)
//                p_min:   drafting stops at the first token whose first candidate has p < p_min
//   prompt_tgt - tokens that precede id_last in the target sequence
//   id_last    - the most recent token; it is not part of prompt_tgt and drafting continues after it
//
// Returns the drafted tokens (possibly none, up to n_draft for n_draft >= 1).
//
// NOTE(review): none of the llama_decode() return values below are checked.
| llama_tokens common_speculative_gen_draft( | ||||
|         struct common_speculative * spec, | ||||
|         struct common_speculative_params params, | ||||
|         const llama_tokens & prompt_tgt, | ||||
|         llama_token id_last) { | ||||
|     auto & batch  = spec->batch; | ||||
|     auto & ctx    = spec->ctx; | ||||
|     auto & smpl   = spec->smpl; | ||||
|     auto & prompt = spec->prompt; | ||||
| 
 | ||||
      // start index (reuse_i) and length (reuse_n) of the best reusable run inside the old draft prompt
|     int reuse_i = 0; | ||||
|     int reuse_n = 0; | ||||
| 
 | ||||
      // draft context size minus n_draft - presumably to leave room for the tokens drafted below
|     const int n_ctx = llama_n_ctx(ctx) - params.n_draft; | ||||
| 
 | ||||
      // if the target prompt is longer than n_ctx, only its last n_ctx tokens are considered
|     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx); | ||||
| 
 | ||||
|     // reuse as much as possible from the old draft context
 | ||||
|     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
 | ||||
      // for every offset i in the old prompt, count how many tokens match prompt_tgt[i_start...]
|     for (int i = 0; i < (int) prompt.size(); ++i) { | ||||
|         int cur = 0; | ||||
|         while (i_start + cur < (int) prompt_tgt.size() && | ||||
|                i       + cur < (int) prompt.size() && | ||||
|                prompt_tgt[i_start + cur] == prompt[i + cur]) { | ||||
|             cur++; | ||||
|         } | ||||
| 
 | ||||
          // keep the longest run; runs shorter than n_reuse only qualify when the whole target prompt fits in n_ctx
|         if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { | ||||
|             reuse_i = i; | ||||
|             reuse_n = cur; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); | ||||
| 
 | ||||
|     llama_tokens result; | ||||
|     result.reserve(params.n_draft); | ||||
| 
 | ||||
|     if (reuse_n == 0) { | ||||
          // nothing reusable - start from an empty draft context
|         llama_kv_cache_clear(ctx); | ||||
| 
 | ||||
|         prompt.clear(); | ||||
|     } else { | ||||
|         // this happens when a previous draft has been discarded (for example, due to being too small), but the
 | ||||
|         // target model agreed with it. in this case, we simply pass back the previous results to save compute
 | ||||
|         if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { | ||||
              // return the tokens recorded after id_last without touching the draft context
|             for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { | ||||
|                 result.push_back(prompt[i]); | ||||
| 
 | ||||
|                 if (params.n_draft <= (int) result.size()) { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             return result; | ||||
|         } | ||||
| 
 | ||||
          // drop everything before the reused run (presumably the seq_add call shifts the remaining positions down by reuse_i)
|         if (reuse_i > 0) { | ||||
|             llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); | ||||
|             llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); | ||||
| 
 | ||||
|             prompt.erase(prompt.begin(), prompt.begin() + reuse_i); | ||||
|         } | ||||
| 
 | ||||
          // drop everything after the reused run
|         if (reuse_n < (int) prompt.size()) { | ||||
|             llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1); | ||||
| 
 | ||||
|             prompt.erase(prompt.begin() + reuse_n, prompt.end()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     // prepare a batch to evaluate any new tokens in the prompt
 | ||||
|     common_batch_clear(batch); | ||||
| 
 | ||||
      // positions are relative to i_start, so the first new token lands at position reuse_n
|     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { | ||||
|         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
 | ||||
|         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); | ||||
| 
 | ||||
|         prompt.push_back(prompt_tgt[i]); | ||||
|     } | ||||
| 
 | ||||
|     // we should rarely end up here during normal decoding
 | ||||
|     if (batch.n_tokens > 0) { | ||||
|         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 | ||||
| 
 | ||||
|         llama_decode(ctx, batch); | ||||
|     } | ||||
| 
 | ||||
      // number of tokens recorded in the draft prompt so far; id_last goes to this position
|     const llama_pos n_past = prompt.size(); | ||||
| 
 | ||||
|     LOG_DBG("%s: n_past = %d\n", __func__, n_past); | ||||
| 
 | ||||
      // id_last is decoded on its own, in a single-token batch
|     common_batch_clear(batch); | ||||
|     common_batch_add  (batch, id_last, n_past, { 0 }, true); | ||||
| 
 | ||||
|     prompt.push_back(id_last); | ||||
| 
 | ||||
|     //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 | ||||
| 
 | ||||
|     llama_decode(ctx, batch); | ||||
| 
 | ||||
|     common_sampler_reset(smpl); | ||||
| 
 | ||||
|     // sample n_draft tokens from the draft model
 | ||||
|     for (int i = 0; i < params.n_draft; ++i) { | ||||
|         common_batch_clear(batch); | ||||
| 
 | ||||
          // the returned token is not used; the candidate list is read from the sampler instead
|         common_sampler_sample(smpl, ctx, 0, true); | ||||
| 
 | ||||
|         const auto * cur_p = common_sampler_get_candidates(smpl); | ||||
| 
 | ||||
          // debug output: log up to 3 candidates
|         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { | ||||
|             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", | ||||
|                     k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); | ||||
|         } | ||||
| 
 | ||||
|         // use the first candidate in the list as the drafted token
 | ||||
|         const llama_token id = cur_p->data[0].id; | ||||
| 
 | ||||
|         // only collect very high-confidence draft tokens
 | ||||
|         if (cur_p->data[0].p < params.p_min) { | ||||
|             break; | ||||
|         } | ||||
| 
 | ||||
|         common_sampler_accept(smpl, id, true); | ||||
| 
 | ||||
|         result.push_back(id); | ||||
| 
 | ||||
          // the draft is full: stop before decoding, so the last drafted token is not added to `prompt`
|         if (params.n_draft <= (int) result.size()) { | ||||
|             break; | ||||
|         } | ||||
| 
 | ||||
|         common_batch_add(batch, id, n_past + i + 1, { 0 }, true); | ||||
| 
 | ||||
|         // evaluate the drafted tokens on the draft model
 | ||||
|         llama_decode(ctx, batch); | ||||
| 
 | ||||
|         prompt.push_back(id); | ||||
|     } | ||||
| 
 | ||||
|     return result; | ||||
| } | ||||
							
								
								
									
										28
									
								
								common/speculative.h
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								common/speculative.h
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,28 @@ | |||
| #pragma once | ||||
| 
 | ||||
| #include "llama.h" | ||||
| #include "common.h" | ||||
| 
 | ||||
| struct common_speculative; | ||||
| 
 | ||||
// Settings passed by value to common_speculative_gen_draft().
| struct common_speculative_params { | ||||
|     int n_draft = 16;  // max drafted tokens
 | ||||
      // min length of a matching token run for part of the old draft prompt to be reused;
      // the requirement is dropped when the whole target prompt fits in the draft context minus n_draft
|     int n_reuse = 256; | ||||
| 
 | ||||
|     float p_min = 0.9f; // min probability required to accept a token in the draft; drafting stops at the first token below it
 | ||||
| }; | ||||
| 
 | ||||
| struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); | ||||
| 
 | ||||
| void common_speculative_free(struct common_speculative * spec); | ||||
| 
 | ||||
// Check whether the draft model can be used to speculate for the target model.
// Logs an error and returns false when the vocab types differ, when the BOS/EOS tokens or
// their add-BOS/add-EOS settings differ, when the vocab sizes differ by more than
// SPEC_VOCAB_MAX_SIZE_DIFFERENCE, or when the text of any token with
// id >= SPEC_VOCAB_CHECK_START_TOKEN_ID differs between the two vocabularies.
| bool common_speculative_are_compatible( | ||||
|         const struct llama_context * ctx_tgt, | ||||
|         const struct llama_context * ctx_dft); | ||||
| 
 | ||||
| // draft up to params.n_draft tokens with the draft model and return them (the result may be empty)
 | ||||
// `prompt` holds the tokens that precede id_last; id_last is passed separately and drafting continues after it
| llama_tokens common_speculative_gen_draft( | ||||
|                struct common_speculative * spec, | ||||
|         struct common_speculative_params   params, | ||||
|                       const llama_tokens & prompt, | ||||
|                              llama_token   id_last); | ||||
|  | @ -50,5 +50,6 @@ else() | |||
|     add_subdirectory(simple) | ||||
|     add_subdirectory(simple-chat) | ||||
|     add_subdirectory(speculative) | ||||
|     add_subdirectory(speculative-simple) | ||||
|     add_subdirectory(tokenize) | ||||
| endif() | ||||
|  |  | |||
|  | @ -68,10 +68,10 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     llama_sampler * smpl = llama_sampler_chain_init(sparams); | ||||
| 
 | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); | ||||
| 
 | ||||
|     if (ctx == NULL) { | ||||
|         LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); | ||||
|  |  | |||
|  | @ -73,7 +73,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     common_init(); | ||||
| 
 | ||||
|     auto & sparams = params.sparams; | ||||
|     auto & sparams = params.sampling; | ||||
| 
 | ||||
|     console::init(params.simple_io, params.use_color); | ||||
|     atexit([]() { console::cleanup(); }); | ||||
|  |  | |||
|  | @ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ | |||
| 
 | ||||
|     LOG("\n"); | ||||
| 
 | ||||
|     struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams); | ||||
|     struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); | ||||
|     if (!smpl) { | ||||
|         LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); | ||||
|         exit(1); | ||||
|  |  | |||
|  | @ -237,7 +237,7 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm | |||
| 
 | ||||
|     LOG_INF("\n"); | ||||
| 
 | ||||
|     struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams); | ||||
|     struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); | ||||
|     return smpl; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -115,7 +115,7 @@ int main(int argc, char ** argv) { | |||
|     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); | ||||
| 
 | ||||
|     // target model sampling context
 | ||||
|     struct common_sampler * smpl = common_sampler_init(model, params.sparams); | ||||
|     struct common_sampler * smpl = common_sampler_init(model, params.sampling); | ||||
| 
 | ||||
|     // verification n-grams
 | ||||
|     std::vector<ngram_data> ngrams_cur(G); | ||||
|  |  | |||
|  | @ -21,7 +21,7 @@ int main(int argc, char ** argv){ | |||
| 
 | ||||
|     common_init(); | ||||
| 
 | ||||
|     const int n_draft = params.n_draft; | ||||
|     const int n_draft = params.speculative.n_max; | ||||
| 
 | ||||
|     // init llama.cpp
 | ||||
|     llama_backend_init(); | ||||
|  | @ -40,6 +40,7 @@ int main(int argc, char ** argv){ | |||
|     common_ngram_cache ngram_cache_context; | ||||
|     common_ngram_cache ngram_cache_dynamic; | ||||
|     common_ngram_cache ngram_cache_static; | ||||
| 
 | ||||
|     int64_t t_draft_flat_us = 0; | ||||
|     int64_t t_draft_us = 0; | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,7 +22,7 @@ int main(int argc, char ** argv){ | |||
|     common_init(); | ||||
| 
 | ||||
|     // max. number of additional tokens to draft if match is found
 | ||||
|     const int n_draft = params.n_draft; | ||||
|     const int n_draft = params.speculative.n_max; | ||||
| 
 | ||||
|     const bool dump_kv_cache = params.dump_kv_cache; | ||||
| 
 | ||||
|  | @ -102,7 +102,7 @@ int main(int argc, char ** argv){ | |||
| 
 | ||||
|     bool has_eos = false; | ||||
| 
 | ||||
|     struct common_sampler * smpl = common_sampler_init(model, params.sparams); | ||||
|     struct common_sampler * smpl = common_sampler_init(model, params.sampling); | ||||
| 
 | ||||
|     std::vector<llama_token> draft; | ||||
| 
 | ||||
|  |  | |||
|  | @ -100,7 +100,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     common_init(); | ||||
| 
 | ||||
|     auto & sparams = params.sparams; | ||||
|     auto & sparams = params.sampling; | ||||
| 
 | ||||
|     // save choice to use color for later
 | ||||
|     // (note for later: this is a slightly awkward choice)
 | ||||
|  |  | |||
|  | @ -160,7 +160,7 @@ int main(int argc, char ** argv) { | |||
|     for (size_t i = 0; i < clients.size(); ++i) { | ||||
|         auto & client = clients[i]; | ||||
|         client.id = i; | ||||
|         client.smpl = common_sampler_init(model, params.sparams); | ||||
|         client.smpl = common_sampler_init(model, params.sampling); | ||||
|     } | ||||
| 
 | ||||
|     std::vector<llama_token> tokens_system; | ||||
|  |  | |||
|  | @ -282,8 +282,8 @@ int main(int argc, char ** argv) { | |||
|                 return a.second > b.second; | ||||
|             }); | ||||
| 
 | ||||
|             LOG("Top %d similar chunks:\n", params.sparams.top_k); | ||||
|             for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { | ||||
|             LOG("Top %d similar chunks:\n", params.sampling.top_k); | ||||
|             for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) { | ||||
|                 LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str()); | ||||
|                 LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); | ||||
|                 LOG("similarity: %f\n", similarities[i].second); | ||||
|  |  | |||
|  | @ -9,7 +9,7 @@ int main(int argc, char ** argv) { | |||
|     common_params params; | ||||
| 
 | ||||
|     params.prompt = "The quick brown fox"; | ||||
|     params.sparams.seed = 1234; | ||||
|     params.sampling.seed = 1234; | ||||
| 
 | ||||
|     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { | ||||
|         return 1; | ||||
|  | @ -42,7 +42,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     llama_sampler * smpl = llama_sampler_chain_init(sparams); | ||||
| 
 | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed)); | ||||
|     llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed)); | ||||
| 
 | ||||
|     // tokenize prompt
 | ||||
|     auto tokens = common_tokenize(ctx, params.prompt, true); | ||||
|  | @ -106,7 +106,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     llama_sampler * smpl2 = llama_sampler_chain_init(sparams); | ||||
| 
 | ||||
|     llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed)); | ||||
|     llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed)); | ||||
| 
 | ||||
|     printf("\nsecond run: %s", params.prompt.c_str()); | ||||
| 
 | ||||
|  | @ -169,7 +169,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     llama_sampler * smpl3 = llama_sampler_chain_init(sparams); | ||||
| 
 | ||||
|     llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed)); | ||||
|     llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed)); | ||||
| 
 | ||||
|     printf("\nsingle seq run: %s", params.prompt.c_str()); | ||||
| 
 | ||||
|  |  | |||
|  | @ -175,7 +175,7 @@ struct server_slot { | |||
|     // sampling
 | ||||
|     json json_schema; | ||||
| 
 | ||||
|     struct common_sampler_params sparams; | ||||
|     struct common_params_sampling sparams; | ||||
|     struct common_sampler * smpl = nullptr; | ||||
| 
 | ||||
|     llama_token sampled; | ||||
|  | @ -687,7 +687,7 @@ struct server_context { | |||
| 
 | ||||
|             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); | ||||
| 
 | ||||
|             slot.sparams = params.sparams; | ||||
|             slot.sparams = params.sampling; | ||||
| 
 | ||||
|             slot.callback_on_release = [this](int) { | ||||
|                 queue_tasks.pop_deferred_task(); | ||||
|  | @ -743,7 +743,7 @@ struct server_context { | |||
|                 } | ||||
| 
 | ||||
|                 // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
 | ||||
|                 int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens); | ||||
|                 int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); | ||||
| 
 | ||||
|                 // fraction of the common subsequence length compared to the current slot's prompt length
 | ||||
|                 float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size()); | ||||
|  | @ -788,7 +788,7 @@ struct server_context { | |||
|     bool launch_slot_with_task(server_slot & slot, const server_task & task) { | ||||
|         slot_params default_params; | ||||
|         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
 | ||||
|         auto default_sparams = params.sparams; | ||||
|         auto default_sparams = params.sampling; | ||||
|         const auto & data = task.data; | ||||
| 
 | ||||
|         if (data.count("__oaicompat") != 0) { | ||||
|  | @ -1960,7 +1960,7 @@ struct server_context { | |||
| 
 | ||||
|                             if (slot.params.cache_prompt) { | ||||
|                                 // reuse any previously computed tokens that are common with the new prompt
 | ||||
|                                 slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens); | ||||
|                                 slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); | ||||
| 
 | ||||
|                                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
 | ||||
|                                 if (params.n_cache_reuse > 0) { | ||||
|  |  | |||
|  | @ -24,7 +24,6 @@ | |||
| #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" | ||||
| 
 | ||||
| using json = nlohmann::ordered_json; | ||||
| using llama_tokens = std::vector<llama_token>; | ||||
| 
 | ||||
| #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) | ||||
| #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) | ||||
|  | @ -439,62 +438,6 @@ static std::string gen_chatcmplid() { | |||
| // other common utils
 | ||||
| //
 | ||||
| 
 | ||||
| static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) { | ||||
|     size_t i; | ||||
|     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} | ||||
| 
 | ||||
|     return i; | ||||
| } | ||||
| 
 | ||||
| static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) { | ||||
|     // check for empty sequences
 | ||||
|     if (a.empty() || b.empty()) { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     // get the lengths of the input sequences
 | ||||
|     size_t a_len = a.size(); | ||||
|     size_t b_len = b.size(); | ||||
| 
 | ||||
|     // initialize the maximum length of the longest common subsequence (LCS)
 | ||||
|     size_t max_length = 0; | ||||
| 
 | ||||
|     // use two rows instead of a 2D matrix to optimize space
 | ||||
|     std::vector<size_t> prev_row(b_len + 1, 0); | ||||
|     std::vector<size_t> curr_row(b_len + 1, 0); | ||||
| 
 | ||||
|     // iterate through the elements of a
 | ||||
|     for (size_t i = 1; i <= a_len; i++) { | ||||
|         // iterate through the elements of b
 | ||||
|         for (size_t j = 1; j <= b_len; j++) { | ||||
|             // if elements at the current positions match
 | ||||
|             if (a[i - 1] == b[j - 1]) { | ||||
|                 // if it's the first element of either sequences, set LCS length to 1
 | ||||
|                 if (i == 1 || j == 1) { | ||||
|                     curr_row[j] = 1; | ||||
|                 } else { | ||||
|                     // increment LCS length by 1 compared to the previous element
 | ||||
|                     curr_row[j] = prev_row[j - 1] + 1; | ||||
|                 } | ||||
| 
 | ||||
|                 // update max_length if necessary
 | ||||
|                 if (curr_row[j] > max_length) { | ||||
|                     max_length = curr_row[j]; | ||||
|                 } | ||||
|             } else { | ||||
|                 // reset LCS length if elements don't match
 | ||||
|                 curr_row[j] = 0; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // update the previous row for the next iteration
 | ||||
|         prev_row = curr_row; | ||||
|     } | ||||
| 
 | ||||
|     // return the maximum length of the LCS
 | ||||
|     return max_length; | ||||
| } | ||||
| 
 | ||||
| static bool ends_with(const std::string & str, const std::string & suffix) { | ||||
|     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); | ||||
| } | ||||
|  |  | |||
							
								
								
									
										5
									
								
								examples/speculative-simple/CMakeLists.txt
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								examples/speculative-simple/CMakeLists.txt
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,5 @@ | |||
# Build script for the speculative-simple example: one executable built from speculative-simple.cpp.
| set(TARGET llama-speculative-simple) | ||||
| add_executable(${TARGET} speculative-simple.cpp) | ||||
| install(TARGETS ${TARGET} RUNTIME) | ||||
# link against the `common` and `llama` targets plus the platform thread library
| target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) | ||||
# request at least C++11 for this target
| target_compile_features(${TARGET} PRIVATE cxx_std_11) | ||||
							
								
								
									
										12
									
								
								examples/speculative-simple/README.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								examples/speculative-simple/README.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,12 @@ | |||
| # llama.cpp/examples/speculative-simple | ||||
| 
 | ||||
| Demonstration of basic greedy speculative decoding | ||||
| 
 | ||||
| ```bash | ||||
| ./bin/llama-speculative-simple \ | ||||
|     -m  ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \ | ||||
|     -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \ | ||||
|     -f test.txt -c 0 -ngl 99 --color \ | ||||
|     --sampling-seq k --top-k 1 -fa --temp 0.0 \ | ||||
|     -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9 | ||||
| ``` | ||||
							
								
								
									
										273
									
								
								examples/speculative-simple/speculative-simple.cpp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										273
									
								
								examples/speculative-simple/speculative-simple.cpp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,273 @@ | |||
| #include "arg.h" | ||||
| #include "common.h" | ||||
| #include "sampling.h" | ||||
| #include "speculative.h" | ||||
| #include "log.h" | ||||
| #include "llama.h" | ||||
| 
 | ||||
| #include <cstdio> | ||||
| #include <cstring> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| int main(int argc, char ** argv) { | ||||
|     common_params params; | ||||
| 
 | ||||
|     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     if (params.n_predict < -1) { | ||||
|         LOG_ERR("%s: --n-predict must be >= -1\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     common_init(); | ||||
| 
 | ||||
|     if (params.speculative.model.empty()) { | ||||
|         LOG_ERR("%s: --model-draft is required\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     // init llama.cpp
 | ||||
|     llama_backend_init(); | ||||
|     llama_numa_init(params.numa); | ||||
| 
 | ||||
|     llama_model * model_tgt = NULL; | ||||
|     llama_model * model_dft = NULL; | ||||
| 
 | ||||
|     llama_context * ctx_tgt = NULL; | ||||
|     llama_context * ctx_dft = NULL; | ||||
| 
 | ||||
|     // load the target model
 | ||||
|     common_init_result llama_init_tgt = common_init_from_params(params); | ||||
| 
 | ||||
|     model_tgt = llama_init_tgt.model; | ||||
|     ctx_tgt   = llama_init_tgt.context; | ||||
| 
 | ||||
|     // load the draft model
 | ||||
|     params.model        = params.speculative.model; | ||||
|     params.n_ctx        = params.speculative.n_ctx; | ||||
|     params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch; | ||||
|     params.n_gpu_layers = params.speculative.n_gpu_layers; | ||||
| 
 | ||||
|     if (params.speculative.cpuparams.n_threads > 0) { | ||||
|         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; | ||||
|     } | ||||
| 
 | ||||
|     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; | ||||
|     common_init_result llama_init_dft = common_init_from_params(params); | ||||
| 
 | ||||
|     model_dft = llama_init_dft.model; | ||||
|     ctx_dft   = llama_init_dft.context; | ||||
| 
 | ||||
|     if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     // Tokenize the prompt
 | ||||
|     std::vector<llama_token> inp; | ||||
|     inp = common_tokenize(ctx_tgt, params.prompt, true, true); | ||||
| 
 | ||||
|     if (llama_n_ctx(ctx_tgt) < (int) inp.size()) { | ||||
|         LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); | ||||
| 
 | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     if (llama_n_batch(ctx_tgt) < (int) inp.size()) { | ||||
|         LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt)); | ||||
| 
 | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     LOG("\n\n"); | ||||
| 
 | ||||
|     for (auto id : inp) { | ||||
|         LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); | ||||
|     } | ||||
| 
 | ||||
|     // how many tokens to draft each time
 | ||||
|     int n_draft     = params.speculative.n_max; | ||||
|     int n_draft_min = params.speculative.n_min; | ||||
| 
 | ||||
|     float p_min = params.speculative.p_min; | ||||
| 
 | ||||
|     int n_predict = 0; | ||||
|     int n_drafted = 0; | ||||
|     int n_accept  = 0; | ||||
| 
 | ||||
|     // used to determine end of generation
 | ||||
|     bool has_eos = false; | ||||
| 
 | ||||
|     // ================================================
 | ||||
|     // everything until here is standard initialization
 | ||||
|     // the relevant stuff for speculative decoding starts here
 | ||||
| 
 | ||||
|     const auto t_enc_start = ggml_time_us(); | ||||
| 
 | ||||
|     // target model sampling context
 | ||||
|     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); | ||||
| 
 | ||||
|     // eval the prompt
 | ||||
|     llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); | ||||
| 
 | ||||
|     // note: keep the last token separate!
 | ||||
|     llama_token id_last = inp.back(); | ||||
| 
 | ||||
|     // all tokens currently in the target context
 | ||||
|     auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1); | ||||
| 
 | ||||
|     int n_past = inp.size() - 1; | ||||
| 
 | ||||
|     // init the speculator
 | ||||
|     struct common_speculative_params params_spec; | ||||
|     params_spec.n_draft = n_draft; | ||||
|     params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; | ||||
|     params_spec.p_min   = p_min; | ||||
| 
 | ||||
|     struct common_speculative * spec = common_speculative_init(ctx_dft); | ||||
| 
 | ||||
|     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); | ||||
| 
 | ||||
|     const auto t_enc_end = ggml_time_us(); | ||||
| 
 | ||||
|     const auto t_dec_start = ggml_time_us(); | ||||
| 
 | ||||
|     while (true) { | ||||
|         // optionally, generate draft tokens that can be appended to the target batch
 | ||||
|         //
 | ||||
|         // this is the most important part of the speculation. the more probable tokens that are provided here
 | ||||
|         // the better the performance will be. in theory, this computation can be performed asynchronously and even
 | ||||
|         // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
 | ||||
|         // from a cache or lookup tables.
 | ||||
|         //
 | ||||
|         llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last); | ||||
| 
 | ||||
|         //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
 | ||||
| 
 | ||||
|         // always have a token to evaluate from before - id_last
 | ||||
|         common_batch_clear(batch_tgt); | ||||
|         common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true); | ||||
| 
 | ||||
|         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
 | ||||
|         { | ||||
|             // do not waste time on small drafts
 | ||||
|             if (draft.size() < n_draft_min) { | ||||
|                 draft.clear(); | ||||
|             } | ||||
| 
 | ||||
|             for (size_t i = 0; i < draft.size(); ++i) { | ||||
|                 common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); | ||||
|             } | ||||
| 
 | ||||
|             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
 | ||||
| 
 | ||||
|             llama_decode(ctx_tgt, batch_tgt); | ||||
|         } | ||||
| 
 | ||||
|         // sample from the full target batch and return the accepted tokens based on the target sampler
 | ||||
|         //
 | ||||
|         // for each token to be accepted, the sampler would have to sample that same token
 | ||||
|         // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
 | ||||
|         // available logits from the batch and sample the next token until we run out of logits or the sampler
 | ||||
|         // disagrees with the draft
 | ||||
|         //
 | ||||
|         const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); | ||||
| 
 | ||||
|         //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
 | ||||
| 
 | ||||
|         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 | ||||
| 
 | ||||
|         n_past    += ids.size() - 1; | ||||
|         n_drafted += batch_tgt.n_tokens - 1; | ||||
|         n_accept  += ids.size() - 1; | ||||
| 
 | ||||
|         // process the accepted tokens and update contexts
 | ||||
|         //
 | ||||
|         // this is the standard token post-processing that we normally do
 | ||||
|         // in this case, we do it for a group of accepted tokens at once
 | ||||
|         //
 | ||||
|         { | ||||
|             llama_token id; | ||||
|             std::string token_str; | ||||
| 
 | ||||
|             for (size_t i = 0; i < ids.size(); ++i) { | ||||
|                 id = ids[i]; | ||||
| 
 | ||||
|                 ++n_predict; | ||||
| 
 | ||||
|                 if (llama_token_is_eog(model_tgt, id)) { | ||||
|                     has_eos = true; | ||||
|                     break; | ||||
|                 } | ||||
| 
 | ||||
|                 token_str = common_token_to_piece(ctx_tgt, id); | ||||
| 
 | ||||
|                 if (params.use_color && i + 1 < ids.size()) { | ||||
|                     LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str()); | ||||
|                 } else { | ||||
|                     LOG("%s", token_str.c_str()); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { | ||||
|                 break; | ||||
|             } | ||||
| 
 | ||||
|             LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str()); | ||||
| 
 | ||||
|             { | ||||
|                 LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); | ||||
| 
 | ||||
|                 llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); | ||||
|             } | ||||
| 
 | ||||
|             prompt_tgt.push_back(id_last); | ||||
|             prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1); | ||||
| 
 | ||||
|             // remember the last accepted token for the next iteration
 | ||||
|             id_last = id; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     auto t_dec_end = ggml_time_us(); | ||||
| 
 | ||||
|     const int n_input = inp.size(); | ||||
| 
 | ||||
|     LOG("\n\n"); | ||||
| 
 | ||||
|     LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); | ||||
|     LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f)); | ||||
| 
 | ||||
|     LOG_INF("\n"); | ||||
|     LOG_INF("n_draft   = %d\n", n_draft); | ||||
|     LOG_INF("n_predict = %d\n", n_predict); | ||||
|     LOG_INF("n_drafted = %d\n", n_drafted); | ||||
|     LOG_INF("n_accept  = %d\n", n_accept); | ||||
|     LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted); | ||||
| 
 | ||||
|     LOG_INF("\n"); | ||||
|     LOG_INF("draft:\n\n"); | ||||
| 
 | ||||
|     llama_perf_context_print(ctx_dft); | ||||
| 
 | ||||
|     LOG_INF("\n"); | ||||
|     LOG_INF("target:\n\n"); | ||||
|     common_perf_print(ctx_tgt, smpl); | ||||
| 
 | ||||
|     common_sampler_free(smpl); | ||||
|     common_speculative_free(spec); | ||||
| 
 | ||||
|     llama_free(ctx_tgt); | ||||
|     llama_free_model(model_tgt); | ||||
| 
 | ||||
|     llama_free(ctx_dft); | ||||
|     llama_free_model(model_dft); | ||||
| 
 | ||||
|     llama_backend_free(); | ||||
| 
 | ||||
|     LOG("\n\n"); | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
|  | @ -12,7 +12,7 @@ | |||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100 | ||||
| #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128 | ||||
| #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 | ||||
| 
 | ||||
| struct seq_draft { | ||||
|  | @ -33,7 +33,7 @@ int main(int argc, char ** argv) { | |||
|     common_params params; | ||||
| 
 | ||||
|     // needed to get candidate probs even for temp <= 0.0
 | ||||
|     params.sparams.n_probs = 128; | ||||
|     params.sampling.n_probs = 128; | ||||
| 
 | ||||
|     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { | ||||
|         return 1; | ||||
|  | @ -46,7 +46,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|     common_init(); | ||||
| 
 | ||||
|     if (params.model_draft.empty()) { | ||||
|     if (params.speculative.model.empty()) { | ||||
|         LOG_ERR("%s: --model-draft is required\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
|  | @ -55,9 +55,9 @@ int main(int argc, char ** argv) { | |||
|     const int n_seq_dft = params.n_parallel; | ||||
| 
 | ||||
|     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
 | ||||
|     const float p_split  = params.p_split; | ||||
|     const float p_draft_split = params.speculative.p_split; | ||||
| 
 | ||||
|     std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed); | ||||
|     std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed); | ||||
|     std::uniform_real_distribution<> u_dist; | ||||
| 
 | ||||
|     // init llama.cpp
 | ||||
|  | @ -76,13 +76,13 @@ int main(int argc, char ** argv) { | |||
|     ctx_tgt = llama_init_tgt.context; | ||||
| 
 | ||||
|     // load the draft model
 | ||||
|     params.model = params.model_draft; | ||||
|     params.n_gpu_layers = params.n_gpu_layers_draft; | ||||
|     if (params.draft_cpuparams.n_threads > 0) { | ||||
|         params.cpuparams.n_threads = params.draft_cpuparams.n_threads; | ||||
|     params.model = params.speculative.model; | ||||
|     params.n_gpu_layers = params.speculative.n_gpu_layers; | ||||
|     if (params.speculative.cpuparams.n_threads > 0) { | ||||
|         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; | ||||
|     } | ||||
| 
 | ||||
|     params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; | ||||
|     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; | ||||
|     common_init_result llama_init_dft = common_init_from_params(params); | ||||
|     model_dft = llama_init_dft.model; | ||||
|     ctx_dft = llama_init_dft.context; | ||||
|  | @ -170,7 +170,7 @@ int main(int argc, char ** argv) { | |||
|     //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 | ||||
| 
 | ||||
|     // how many tokens to draft each time
 | ||||
|     int n_draft = params.n_draft; | ||||
|     int n_draft = params.speculative.n_max; | ||||
| 
 | ||||
|     int n_predict = 0; | ||||
|     int n_drafted = 0; | ||||
|  | @ -183,14 +183,14 @@ int main(int argc, char ** argv) { | |||
|     bool has_eos = false; | ||||
| 
 | ||||
|     // target model sampling context (reuse the llama_context's sampling instance)
 | ||||
|     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); | ||||
|     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); | ||||
| 
 | ||||
|     // draft sequence data
 | ||||
|     std::vector<seq_draft> drafts(n_seq_dft); | ||||
| 
 | ||||
|     for (int s = 0; s < n_seq_dft; ++s) { | ||||
|         // allocate llama_sampler for each draft sequence
 | ||||
|         drafts[s].smpl = common_sampler_init(model_dft, params.sparams); | ||||
|         drafts[s].smpl = common_sampler_init(model_dft, params.sampling); | ||||
|     } | ||||
| 
 | ||||
|     llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); | ||||
|  | @ -230,7 +230,7 @@ int main(int argc, char ** argv) { | |||
|             // for stochastic sampling, attempt to match the token with the drafted tokens
 | ||||
|             { | ||||
|                 bool accept = false; | ||||
|                 if (params.sparams.temp > 0) { | ||||
|                 if (params.sampling.temp > 0) { | ||||
|                     // stochastic verification
 | ||||
|                     common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); | ||||
| 
 | ||||
|  | @ -494,7 +494,7 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|                 // attempt to split the branch if the probability is high enough
 | ||||
|                 for (int f = 1; f < 8; ++f) { | ||||
|                     if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { | ||||
|                     if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { | ||||
|                         LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); | ||||
| 
 | ||||
|                         llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1); | ||||
|  |  | |||
|  | @ -70,7 +70,7 @@ int main(void) { | |||
| 
 | ||||
|     // non-existent arg in specific example (--draft cannot be used outside llama-speculative)
 | ||||
|     argv = {"binary_name", "--draft", "123"}; | ||||
|     assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SERVER)); | ||||
|     assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING)); | ||||
| 
 | ||||
| 
 | ||||
|     printf("test-arg-parser: test valid usage\n\n"); | ||||
|  | @ -96,7 +96,7 @@ int main(void) { | |||
|     // --draft cannot be used outside llama-speculative
 | ||||
|     argv = {"binary_name", "--draft", "123"}; | ||||
|     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); | ||||
|     assert(params.n_draft == 123); | ||||
|     assert(params.speculative.n_max == 123); | ||||
| 
 | ||||
| // skip this part on windows, because setenv is not supported
 | ||||
| #ifdef _WIN32 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue