llama : add llama_sampling API + move grammar in libllama

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-08-05 10:08:25 +03:00
parent b69a480af4
commit f648ca2cee
No known key found for this signature in database
GPG key ID: BF970631944C16B7
48 changed files with 2481 additions and 2590 deletions

View file

@ -33,16 +33,21 @@
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
// TODO: use everywhere in the implementation
#define LLAMA_TOKEN_NULL -1
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 8
#define LLAMA_SESSION_VERSION 9
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
#define LLAMA_MAX_SAMPLERS 16
#ifdef __cplusplus
extern "C" {
#endif
@ -53,8 +58,10 @@ extern "C" {
// TODO: show sample usage
//
// struct llama_vocab; // TODO: add in the future
struct llama_model;
struct llama_context;
struct llama_sampling;
typedef int32_t llama_pos;
typedef int32_t llama_token;
@ -199,6 +206,16 @@ extern "C" {
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
enum llama_sampler_type {
LLAMA_SAMPLER_TYPE_NONE = 0,
LLAMA_SAMPLER_TYPE_TOP_K = 1,
LLAMA_SAMPLER_TYPE_TOP_P = 2,
LLAMA_SAMPLER_TYPE_MIN_P = 3,
LLAMA_SAMPLER_TYPE_TFS_Z = 4,
LLAMA_SAMPLER_TYPE_TYPICAL_P = 5,
LLAMA_SAMPLER_TYPE_TEMPERATURE = 6,
};
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@ -206,6 +223,7 @@ extern "C" {
} llama_token_data;
typedef struct llama_token_data_array {
// TODO: consider SoA
llama_token_data * data;
size_t size;
bool sorted;
@ -300,7 +318,6 @@ extern "C" {
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggerganov/llama.cpp/pull/7544
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
uint32_t n_ubatch; // physical maximum batch size
@ -328,7 +345,8 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together to avoid misalignment during copy-by-value.
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
@ -356,53 +374,56 @@ extern "C" {
void * kv_overrides; // pointer to vector containing overrides
} llama_model_quantize_params;
// grammar types
struct llama_grammar;
typedef struct llama_logit_bias {
llama_token token;
float bias;
} llama_logit_bias;
// grammar element type
enum llama_gretype {
// end of rule definition
LLAMA_GRETYPE_END = 0,
// parameters for sampling the logits
typedef struct llama_sampling_params {
uint32_t seed; // the seed used to initialize llama_sampling_context
int32_t n_prev; // number of previous tokens to remember
int32_t n_probs; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k; // <= 0 to use vocab size
float top_p; // 1.0 = disabled
float min_p; // 0.0 = disabled
float tfs_z; // 1.0 = disabled
float typ_p; // typical_p, 1.0 = disabled
float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range; // 0.0 = disabled
float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat; // 1.0 = disabled
float penalty_freq; // 0.0 = disabled
float penalty_present; // 0.0 = disabled
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau; // target entropy
float mirostat_eta; // learning rate
// start of alternate definition for rule
LLAMA_GRETYPE_ALT = 1,
// samplers
int32_t n_samplers;
enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS];
// non-terminal element: reference to rule
LLAMA_GRETYPE_RULE_REF = 2,
// terminal element: character (code point)
LLAMA_GRETYPE_CHAR = 3,
// inverse char(s) ([^a], [^a-b] [^abc])
LLAMA_GRETYPE_CHAR_NOT = 4,
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
// be an inclusive range ([a-z])
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
// modifies a preceding LLAMA_GRETYPE_CHAR or
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
LLAMA_GRETYPE_CHAR_ALT = 6,
// any character (.)
LLAMA_GRETYPE_CHAR_ANY = 7,
};
typedef struct llama_grammar_element {
enum llama_gretype type;
uint32_t value; // Unicode code point or rule ID
} llama_grammar_element;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool penalize_nl; // consider newlines as a repeatable token
bool ignore_eos; // ignore the end-of-sequence token
} llama_sampling_params;
// performance timing information
struct llama_timings {
double t_start_ms;
double t_end_ms;
double t_load_ms;
double t_sample_ms;
double t_sampling_ms;
double t_grammar_ms;
double t_accept_ms;
double t_p_eval_ms;
double t_eval_ms;
int32_t n_sample;
int32_t n_sampling;
int32_t n_grammar;
int32_t n_accept;
int32_t n_p_eval;
int32_t n_eval;
};
@ -417,8 +438,9 @@ extern "C" {
struct llama_lora_adapter;
// Helpers for getting default parameters
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
// Initialize the llama + ggml backend
@ -445,6 +467,7 @@ extern "C" {
LLAMA_API void llama_free_model(struct llama_model * model);
// TODO: rename to llama_init_from_model
LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params);
@ -460,23 +483,22 @@ extern "C" {
LLAMA_API bool llama_supports_mlock (void);
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@ -704,7 +726,7 @@ extern "C" {
//
// Returns the *actual* size in bytes of the state
// (rng, logits, embedding and kv_cache)
// (logits, embedding and kv_cache)
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@ -1006,159 +1028,131 @@ extern "C" {
char * buf,
int32_t length);
//
// Grammar
//
/// Initialize a llama_grammar.
///
/// @param rules The rule elements of the grammar to initialize.
/// @param n_rules The number of rules.
/// @param start_rule_index The index of the root rule (the starting point of the grammar).
/// @return The initialized llama_grammar or nullptr if initialization failed.
LLAMA_API struct llama_grammar * llama_grammar_init(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
/// @details Apply constraints from grammar
LLAMA_API void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates);
LLAMA_API DEPRECATED(void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar),
"use llama_grammar_sample instead");
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token);
//
// Sampling functions
//
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
// TODO: llama_model should become llama_vocab
LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
// Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
// - clear prev token
// - reset grammar state
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
// Sampling parameter mutation
// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
// Set the logits from which to sample.
// This call initializes the internal token candidates array.
// The internal candidates are implicitly used by the sampling API below when no candidates are provided.
LLAMA_API void llama_sampling_set_logits(
struct llama_sampling * smpl,
const float * logits);
/// @details Returns the current candidate tokens.
LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
struct llama_sampling * smpl);
// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
// Each function can accept an array of token candidates. If the candidates are not provided, the internal
// candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sampling_softmax(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sampling_top_k(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sampling_top_p(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API void llama_sampling_min_p(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sampling_tail_free(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sampling_typical(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Apply temperature and entropy
LLAMA_API void llama_sampling_temp(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Apply constraints from grammar
LLAMA_API void llama_sampling_grammar(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present);
LLAMA_API void llama_sampling_penalties(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param logits Logits extracted from the original generation context.
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
LLAMA_API void llama_sample_apply_guidance(
struct llama_context * ctx,
float * logits,
float * logits_guidance,
float scale);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(
struct llama_context * ctx,
llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_k(
struct llama_context * ctx,
llama_token_data_array * candidates,
int32_t k,
size_t min_keep);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
float p,
size_t min_keep);
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API void llama_sample_min_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
float p,
size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
llama_token_data_array * candidates,
float z,
size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical(
struct llama_context * ctx,
llama_token_data_array * candidates,
float p,
size_t min_keep);
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API void llama_sample_entropy(
struct llama_context * ctx,
llama_token_data_array * candidates_p,
float min_temp,
float max_temp,
float exponent_val);
LLAMA_API void llama_sample_temp(
struct llama_context * ctx,
llama_token_data_array * candidates,
float temp);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat(
struct llama_context * ctx,
llama_token_data_array * candidates,
float tau,
float eta,
int32_t m,
float * mu);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat_v2(
struct llama_context * ctx,
llama_token_data_array * candidates,
float tau,
float eta,
float * mu);
/// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
LLAMA_API llama_token llama_sampling_sample_mirostat(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Selects the token with the highest probability.
/// Does not compute the token probabilities. Use llama_sample_softmax() instead.
LLAMA_API llama_token llama_sample_token_greedy(
struct llama_context * ctx,
llama_token_data_array * candidates);
/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
LLAMA_API llama_token llama_sampling_sample_greedy(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
LLAMA_API llama_token llama_sample_token(
struct llama_context * ctx,
llama_token_data_array * candidates);
/// @details Randomly selects a token from the candidates based on their probability distribution.
LLAMA_API llama_token llama_sampling_sample_dist(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
LLAMA_API llama_token llama_sampling_sample(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
/// @details Accepts the sampled token into the sampling context.
/// - adds it to "prev" tokens
/// - updates the grammar state (if apply_grammar is true)
LLAMA_API void llama_sampling_accept(
struct llama_sampling * smpl,
llama_token token,
bool apply_grammar);
/// @details Get the number of accepted tokens so far (max of n_prev)
LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
/// @details Get the ith accepted token
/// @param ith [0, n_prev), ith == 0 is the last accepted token.
/// returns LLAMA_TOKEN_NULL if ith is out of bounds
LLAMA_API llama_token llama_sampling_prev(
const struct llama_sampling * smpl,
int32_t ith);
/// @details Get the last accepted token
/// Same as llama_sampling_prev(smpl, 0)
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
//
// Model split
@ -1177,8 +1171,8 @@ extern "C" {
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl);
LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl);
// Print system information
LLAMA_API const char * llama_print_system_info(void);
@ -1193,59 +1187,4 @@ extern "C" {
}
#endif
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
#include <random>
#include <string>
#include <vector>
struct ggml_tensor;
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
struct llama_context * ctx
);
struct llama_partial_utf8 {
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};
struct llama_grammar_candidate {
size_t index;
const uint32_t * code_points;
llama_partial_utf8 partial_utf8;
};
using llama_grammar_rule = std::vector< llama_grammar_element>;
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
using llama_grammar_rules = std::vector<llama_grammar_rule>;
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & new_stacks);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
const llama_grammar_candidates & candidates);
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
llama_partial_utf8 partial_start);
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
#endif // LLAMA_API_INTERNAL
#endif // LLAMA_H