speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
This commit is contained in:
parent
c67fe68e41
commit
0e89203b51
21 changed files with 737 additions and 578 deletions
|
@ -2,6 +2,8 @@
|
|||
|
||||
#include "llama.h"
|
||||
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
@ -34,75 +36,64 @@ typedef struct llama_sampling_params {
|
|||
|
||||
} llama_sampling_params;
|
||||
|
||||
// per-sequence sampler context
|
||||
typedef struct llama_sampler_sequence_context {
|
||||
float mirostat_mu; // mirostat sampler state
|
||||
llama_grammar * grammar;
|
||||
} llama_sampler_sequence_context;
|
||||
|
||||
// general sampler context
|
||||
typedef struct llama_sampling_context {
|
||||
~llama_sampling_context();
|
||||
|
||||
// parameters that will be used for sampling and when creating
|
||||
// new llama_sampler_sequence_context instances
|
||||
// TODO: move to llama.h
|
||||
struct llama_sampling_context {
|
||||
// parameters that will be used for sampling
|
||||
llama_sampling_params params;
|
||||
|
||||
// map of sequence ids to sampler contexts
|
||||
std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
|
||||
// mirostat sampler state
|
||||
float mirostat_mu;
|
||||
|
||||
// when non-NULL, new instances of llama_sampler_sequence_context
|
||||
// will get a copy of the grammar here
|
||||
// note: only the pointer is stored here, it is not a copy of
|
||||
// the grammar and shouldn't be freed
|
||||
llama_grammar * grammar;
|
||||
} llama_sampling_context;
|
||||
|
||||
// internal
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
};
|
||||
|
||||
#include "common.h"
|
||||
|
||||
// Create a new sampling context instance.
|
||||
llama_sampling_context llama_sampling_context_init(
|
||||
const struct gpt_params & params,
|
||||
llama_grammar * grammar = NULL);
|
||||
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
|
||||
|
||||
// Fetches the sampler context for the specified sequence id (defaults to 0).
|
||||
// If the context for that sequence id doesn't already exist, it will be created with
|
||||
// default values based on the parameters in the ctx_sampling argument.
|
||||
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
|
||||
llama_sampling_context & ctx_sampling,
|
||||
const llama_seq_id seq = 0);
|
||||
void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||
|
||||
// Reset the sampler context for the supplied sequence id (defaults to 0).
|
||||
// This is necessary to reuse a sequence id or free memory used by sequences
|
||||
// that are no longer required.
|
||||
bool llama_sampling_context_reset(
|
||||
llama_sampling_context & ctx_sampling,
|
||||
const llama_seq_id seq = 0);
|
||||
// Reset the sampler context
|
||||
// - clear prev tokens
|
||||
// - reset grammar
|
||||
void llama_sampling_reset(llama_sampling_context * ctx);
|
||||
|
||||
// Copy the sampler context
|
||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
|
||||
|
||||
// this is a common sampling function used across the examples for convenience
|
||||
// it can serve as a starting point for implementing your own sampling function
|
||||
// Note: When using multiple sequences, it is the caller's responsibility to call
|
||||
// llama_sampling_context_reset when a sequence ends
|
||||
// llama_sampling_reset when a sequence ends
|
||||
//
|
||||
// required:
|
||||
// - ctx: context to use for sampling
|
||||
// - ctx_main: context to use for sampling
|
||||
// - ctx_sampling: sampling-specific context
|
||||
//
|
||||
// optional:
|
||||
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
|
||||
// - last_tokens: needed for repetition penalty, ignore if empty
|
||||
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
||||
// - seq: sequence id to associate sampler state with
|
||||
// - ctx_cfg: context to use for classifier-free guidance
|
||||
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
||||
//
|
||||
// returns:
|
||||
// - token: sampled token
|
||||
// - candidates: vector of candidate tokens
|
||||
//
|
||||
llama_token llama_sampling_sample(
|
||||
struct llama_context * ctx,
|
||||
struct llama_context * ctx_guidance,
|
||||
struct llama_sampling_context & ctx_sampling,
|
||||
const std::vector<llama_token> & last_tokens,
|
||||
std::vector<llama_token_data> & candidates,
|
||||
const int idx = 0,
|
||||
llama_seq_id seq = 0);
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
int idx = 0);
|
||||
|
||||
void llama_sampling_accept(
|
||||
struct llama_sampling_context * ctx_sampling,
|
||||
struct llama_context * ctx_main,
|
||||
llama_token id);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue