speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
This commit is contained in:
		
							parent
							
								
									c67fe68e41
								
							
						
					
					
						commit
						0e89203b51
					
				
					 21 changed files with 737 additions and 578 deletions
				
			
		|  | @ -2,6 +2,8 @@ | |||
| 
 | ||||
| #include "llama.h" | ||||
| 
 | ||||
| #include "grammar-parser.h" | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include <unordered_map> | ||||
|  | @ -34,75 +36,64 @@ typedef struct llama_sampling_params { | |||
| 
 | ||||
| } llama_sampling_params; | ||||
| 
 | ||||
| // per-sequence sampler context
 | ||||
| typedef struct llama_sampler_sequence_context { | ||||
|     float mirostat_mu; // mirostat sampler state
 | ||||
|     llama_grammar * grammar; | ||||
| } llama_sampler_sequence_context; | ||||
| 
 | ||||
| // general sampler context
 | ||||
| typedef struct llama_sampling_context { | ||||
|     ~llama_sampling_context(); | ||||
| 
 | ||||
|     // parameters that will be used for sampling and when creating
 | ||||
|     // new llama_sampler_sequence_context instances
 | ||||
| // TODO: move to llama.h
 | ||||
| struct llama_sampling_context { | ||||
|     // parameters that will be used for sampling
 | ||||
|     llama_sampling_params params; | ||||
| 
 | ||||
|     // map of sequence ids to sampler contexts
 | ||||
|     std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts; | ||||
|     // mirostat sampler state
 | ||||
|     float mirostat_mu; | ||||
| 
 | ||||
|     // when non-NULL, new instances of llama_sampler_sequence_context
 | ||||
|     // will get a copy of the grammar here
 | ||||
|     // note: only the pointer is stored here, it is not a copy of
 | ||||
|     //       the grammar and shouldn't be freed
 | ||||
|     llama_grammar * grammar; | ||||
| } llama_sampling_context; | ||||
| 
 | ||||
|     // internal
 | ||||
|     grammar_parser::parse_state parsed_grammar; | ||||
| 
 | ||||
|     // TODO: replace with ring-buffer
 | ||||
|     std::vector<llama_token>      prev; | ||||
|     std::vector<llama_token_data> cur; | ||||
| }; | ||||
| 
 | ||||
| #include "common.h" | ||||
| 
 | ||||
| // Create a new sampling context instance.
 | ||||
| llama_sampling_context llama_sampling_context_init( | ||||
|         const struct gpt_params & params, | ||||
|                   llama_grammar * grammar = NULL); | ||||
| struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params); | ||||
| 
 | ||||
| // Fetches the sampler context for the specified sequence id (defaults to 0).
 | ||||
| // If the context for that sequence id doesn't already exist, it will be created with
 | ||||
| // default values based on the parameters in the ctx_sampling argument.
 | ||||
| llama_sampler_sequence_context & llama_sampling_get_sequence_context( | ||||
|               llama_sampling_context & ctx_sampling, | ||||
|         const llama_seq_id             seq = 0); | ||||
| void llama_sampling_free(struct llama_sampling_context * ctx); | ||||
| 
 | ||||
| // Reset the sampler context for the supplied sequence id (defaults to 0).
 | ||||
| // This is necessary to reuse a sequence id or free memory used by sequences
 | ||||
| // that are no longer required.
 | ||||
| bool llama_sampling_context_reset( | ||||
|               llama_sampling_context & ctx_sampling, | ||||
|         const llama_seq_id             seq = 0); | ||||
| // Reset the sampler context
 | ||||
| // - clear prev tokens
 | ||||
| // - reset grammar
 | ||||
| void llama_sampling_reset(llama_sampling_context * ctx); | ||||
| 
 | ||||
| // Copy the sampler context
 | ||||
| void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); | ||||
| 
 | ||||
| // this is a common sampling function used across the examples for convenience
 | ||||
| // it can serve as a starting point for implementing your own sampling function
 | ||||
| // Note: When using multiple sequences, it is the caller's responsibility to call
 | ||||
| //       llama_sampling_context_reset when a sequence ends
 | ||||
| //       llama_sampling_reset when a sequence ends
 | ||||
| //
 | ||||
| // required:
 | ||||
| //  - ctx:          context to use for sampling
 | ||||
| //  - ctx_main:     context to use for sampling
 | ||||
| //  - ctx_sampling: sampling-specific context
 | ||||
| //
 | ||||
| // optional:
 | ||||
| //  - ctx_guidance:  context to use for classifier-free guidance, ignore if NULL
 | ||||
| //  - last_tokens:   needed for repetition penalty, ignore if empty
 | ||||
| //  - idx:           sample from llama_get_logits_ith(ctx, idx)
 | ||||
| //  - seq:           sequence id to associate sampler state with
 | ||||
| //  - ctx_cfg:      context to use for classifier-free guidance
 | ||||
| //  - idx:          sample from llama_get_logits_ith(ctx, idx)
 | ||||
| //
 | ||||
| // returns:
 | ||||
| //  - token:      sampled token
 | ||||
| //  - candidates: vector of candidate tokens
 | ||||
| //
 | ||||
| llama_token llama_sampling_sample( | ||||
|                   struct llama_context * ctx, | ||||
|                   struct llama_context * ctx_guidance, | ||||
|                   struct llama_sampling_context & ctx_sampling, | ||||
|         const std::vector<llama_token> & last_tokens, | ||||
|          std::vector<llama_token_data> & candidates, | ||||
|         const                      int   idx = 0, | ||||
|                           llama_seq_id   seq = 0); | ||||
|         struct llama_sampling_context * ctx_sampling, | ||||
|         struct llama_context * ctx_main, | ||||
|         struct llama_context * ctx_cfg, | ||||
|         int idx = 0); | ||||
| 
 | ||||
| void llama_sampling_accept( | ||||
|         struct llama_sampling_context * ctx_sampling, | ||||
|         struct llama_context * ctx_main, | ||||
|         llama_token id); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue