sampling : avoid expensive softmax during greedy sampling (#9605)
* sampling : avoid expensive softmax during greedy sampling ggml-ci * speculative : fix default RNG seed + set sparams.n_probs * Update tests/test-sampling.cpp Co-authored-by: slaren <slarengh@gmail.com> * sampling : add clarifying comment [no ci] --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
		
							parent
							
								
									c087b6f11d
								
							
						
					
					
						commit
						b0f27361f3
					
				
					 5 changed files with 59 additions and 6 deletions
				
			
		|  | @ -209,7 +209,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st | ||||||
|             GGML_ASSERT(false && "unknown mirostat version"); |             GGML_ASSERT(false && "unknown mirostat version"); | ||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
|         llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); |         if (params.n_probs > 0) { | ||||||
|  |             // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
 | ||||||
|  |             // ref: https://github.com/ggerganov/llama.cpp/pull/9605
 | ||||||
|  |             //
 | ||||||
|  |             // the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
 | ||||||
|  |             // it is much faster, since we avoid sorting all tokens and should give a good approximation
 | ||||||
|  |             llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs)); | ||||||
|  |             llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); | ||||||
|  |         } | ||||||
|         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); |         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -32,6 +32,9 @@ struct seq_draft { | ||||||
| int main(int argc, char ** argv) { | int main(int argc, char ** argv) { | ||||||
|     gpt_params params; |     gpt_params params; | ||||||
| 
 | 
 | ||||||
|  |     // needed to get candidate probs even for temp <= 0.0
 | ||||||
|  |     params.sparams.n_probs = 128; | ||||||
|  | 
 | ||||||
|     if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { |     if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
|  | @ -49,7 +52,7 @@ int main(int argc, char ** argv) { | ||||||
|     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
 |     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
 | ||||||
|     const float p_split  = params.p_split; |     const float p_split  = params.p_split; | ||||||
| 
 | 
 | ||||||
|     std::default_random_engine rng(params.sparams.seed); |     std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed); | ||||||
|     std::uniform_real_distribution<> u_dist; |     std::uniform_real_distribution<> u_dist; | ||||||
| 
 | 
 | ||||||
|     // init llama.cpp
 |     // init llama.cpp
 | ||||||
|  |  | ||||||
|  | @ -1066,6 +1066,7 @@ extern "C" { | ||||||
|     LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed); |     LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed); | ||||||
| 
 | 
 | ||||||
|     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
 |     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
 | ||||||
|  |     /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
 | ||||||
|     LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void); |     LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void); | ||||||
| 
 | 
 | ||||||
|     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
 |     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
 | ||||||
|  |  | ||||||
|  | @ -3,13 +3,14 @@ | ||||||
| #include "llama-vocab.h" | #include "llama-vocab.h" | ||||||
| #include "llama-grammar.h" | #include "llama-grammar.h" | ||||||
| 
 | 
 | ||||||
| #include <cassert> |  | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <cstring> | #include <cassert> | ||||||
| #include <ctime> |  | ||||||
| #include <cfloat> | #include <cfloat> | ||||||
| #include <chrono> | #include <chrono> | ||||||
| #include <cmath> | #include <cmath> | ||||||
|  | #include <cstdlib> | ||||||
|  | #include <cstring> | ||||||
|  | #include <ctime> | ||||||
| #include <numeric> | #include <numeric> | ||||||
| #include <random> | #include <random> | ||||||
| #include <unordered_map> | #include <unordered_map> | ||||||
|  |  | ||||||
|  | @ -1,6 +1,5 @@ | ||||||
| #include "ggml.h" | #include "ggml.h" | ||||||
| #include "llama.h" | #include "llama.h" | ||||||
| #include "llama-sampling.h" |  | ||||||
| 
 | 
 | ||||||
| #ifdef NDEBUG | #ifdef NDEBUG | ||||||
| #undef NDEBUG | #undef NDEBUG | ||||||
|  | @ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler | ||||||
|            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); |            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) { | ||||||
|  |     std::vector<llama_token_data> cur(data.size()); | ||||||
|  |     std::copy(data.begin(), data.end(), cur.begin()); | ||||||
|  |     llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; | ||||||
|  |     llama_sampler_apply(cnstr, &cur_p); | ||||||
|  |     llama_sampler_reset(cnstr); | ||||||
|  |     const int64_t t_start = ggml_time_us(); | ||||||
|  |     for (int i = 0; i < n_iter; i++) { | ||||||
|  |         std::copy(data.begin(), data.end(), cur.begin()); | ||||||
|  |         llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; | ||||||
|  |         llama_sampler_apply(cnstr, &cur_p); | ||||||
|  |         llama_sampler_reset(cnstr); | ||||||
|  |     } | ||||||
|  |     const int64_t t_end = ggml_time_us(); | ||||||
|  |     llama_sampler_free(cnstr); | ||||||
|  |     printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter)) | ||||||
|  | 
 | ||||||
|  | static void test_perf() { | ||||||
|  |     const int n_vocab = 1 << 17; | ||||||
|  | 
 | ||||||
|  |     std::vector<llama_token_data> data; | ||||||
|  | 
 | ||||||
|  |     data.reserve(n_vocab); | ||||||
|  |     for (int i = 0; i < n_vocab; i++) { | ||||||
|  |         const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f); | ||||||
|  |         data.emplace_back(llama_token_data{i, logit, 0.0f}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     BENCH(llama_sampler_init_top_k    (40),      data, 32); | ||||||
|  |     BENCH(llama_sampler_init_top_p    (0.8f, 1), data, 32); | ||||||
|  |     BENCH(llama_sampler_init_min_p    (0.2f, 1), data, 32); | ||||||
|  |     BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32); | ||||||
|  |     BENCH(llama_sampler_init_typical  (0.5f, 1), data, 32); | ||||||
|  |     BENCH(llama_sampler_init_softmax  (),        data, 32); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| int main(void) { | int main(void) { | ||||||
|     ggml_time_init(); |     ggml_time_init(); | ||||||
| 
 | 
 | ||||||
|  | @ -316,5 +354,7 @@ int main(void) { | ||||||
| 
 | 
 | ||||||
|     printf("OK\n"); |     printf("OK\n"); | ||||||
| 
 | 
 | ||||||
|  |     test_perf(); | ||||||
|  | 
 | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue