cont : add penalties and logit-bias constraints [no ci]
This commit is contained in:
parent
0daebc6b8d
commit
a2ce91cbef
6 changed files with 556 additions and 119 deletions
|
@ -128,57 +128,57 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
char llama_sampling_type_to_chr(llama_sampler_type sampler) {
|
char llama_sampling_type_to_chr(llama_constraint_type sampler) {
|
||||||
switch (sampler) {
|
switch (sampler) {
|
||||||
case LLAMA_SAMPLER_TYPE_TOP_K: return 'k';
|
case LLAMA_CONSTRAINT_TYPE_TOP_K: return 'k';
|
||||||
case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f';
|
case LLAMA_CONSTRAINT_TYPE_TFS_Z: return 'f';
|
||||||
case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return 'y';
|
||||||
case LLAMA_SAMPLER_TYPE_TOP_P: return 'p';
|
case LLAMA_CONSTRAINT_TYPE_TOP_P: return 'p';
|
||||||
case LLAMA_SAMPLER_TYPE_MIN_P: return 'm';
|
case LLAMA_CONSTRAINT_TYPE_MIN_P: return 'm';
|
||||||
case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't';
|
case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return 't';
|
||||||
default : return '?';
|
default : return '?';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string llama_sampling_type_to_str(llama_sampler_type sampler) {
|
std::string llama_sampling_type_to_str(llama_constraint_type sampler) {
|
||||||
switch (sampler) {
|
switch (sampler) {
|
||||||
case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k";
|
case LLAMA_CONSTRAINT_TYPE_TOP_K: return "top_k";
|
||||||
case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z";
|
case LLAMA_CONSTRAINT_TYPE_TFS_Z: return "tfs_z";
|
||||||
case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p";
|
||||||
case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p";
|
case LLAMA_CONSTRAINT_TYPE_TOP_P: return "top_p";
|
||||||
case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p";
|
case LLAMA_CONSTRAINT_TYPE_MIN_P: return "min_p";
|
||||||
case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return "temperature";
|
||||||
default : return "";
|
default : return "";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
std::vector<llama_constraint_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
||||||
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
|
std::unordered_map<std::string, llama_constraint_type> sampler_canonical_name_map {
|
||||||
{ "top_k", LLAMA_SAMPLER_TYPE_TOP_K },
|
{ "top_k", LLAMA_CONSTRAINT_TYPE_TOP_K },
|
||||||
{ "top_p", LLAMA_SAMPLER_TYPE_TOP_P },
|
{ "top_p", LLAMA_CONSTRAINT_TYPE_TOP_P },
|
||||||
{ "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
{ "typ_p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P },
|
||||||
{ "min_p", LLAMA_SAMPLER_TYPE_MIN_P },
|
{ "min_p", LLAMA_CONSTRAINT_TYPE_MIN_P },
|
||||||
{ "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z },
|
{ "tfs_z", LLAMA_CONSTRAINT_TYPE_TFS_Z },
|
||||||
{ "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE },
|
{ "temperature", LLAMA_CONSTRAINT_TYPE_TEMPERATURE },
|
||||||
};
|
};
|
||||||
|
|
||||||
// since samplers names are written multiple ways
|
// since samplers names are written multiple ways
|
||||||
// make it ready for both system names and input names
|
// make it ready for both system names and input names
|
||||||
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
|
std::unordered_map<std::string, llama_constraint_type> sampler_alt_name_map {
|
||||||
{ "top-k", LLAMA_SAMPLER_TYPE_TOP_K },
|
{ "top-k", LLAMA_CONSTRAINT_TYPE_TOP_K },
|
||||||
{ "top-p", LLAMA_SAMPLER_TYPE_TOP_P },
|
{ "top-p", LLAMA_CONSTRAINT_TYPE_TOP_P },
|
||||||
{ "nucleus", LLAMA_SAMPLER_TYPE_TOP_P },
|
{ "nucleus", LLAMA_CONSTRAINT_TYPE_TOP_P },
|
||||||
{ "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
{ "typical-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P },
|
||||||
{ "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
{ "typical", LLAMA_CONSTRAINT_TYPE_TYPICAL_P },
|
||||||
{ "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
{ "typ-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P },
|
||||||
{ "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
{ "typ", LLAMA_CONSTRAINT_TYPE_TYPICAL_P },
|
||||||
{ "min-p", LLAMA_SAMPLER_TYPE_MIN_P },
|
{ "min-p", LLAMA_CONSTRAINT_TYPE_MIN_P },
|
||||||
{ "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z },
|
{ "tfs-z", LLAMA_CONSTRAINT_TYPE_TFS_Z },
|
||||||
{ "tfs", LLAMA_SAMPLER_TYPE_TFS_Z },
|
{ "tfs", LLAMA_CONSTRAINT_TYPE_TFS_Z },
|
||||||
{ "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE },
|
{ "temp", LLAMA_CONSTRAINT_TYPE_TEMPERATURE },
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers;
|
std::vector<llama_constraint_type> samplers;
|
||||||
samplers.reserve(names.size());
|
samplers.reserve(names.size());
|
||||||
|
|
||||||
for (const auto & name : names) {
|
for (const auto & name : names) {
|
||||||
|
@ -198,17 +198,17 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
|
||||||
return samplers;
|
return samplers;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars) {
|
std::vector<llama_constraint_type> llama_sampling_types_from_chars(const std::string & chars) {
|
||||||
std::unordered_map<char, llama_sampler_type> sampler_name_map {
|
std::unordered_map<char, llama_constraint_type> sampler_name_map {
|
||||||
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K },
|
{ llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_K), LLAMA_CONSTRAINT_TYPE_TOP_K },
|
||||||
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z },
|
{ llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TFS_Z), LLAMA_CONSTRAINT_TYPE_TFS_Z },
|
||||||
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
{ llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TYPICAL_P), LLAMA_CONSTRAINT_TYPE_TYPICAL_P },
|
||||||
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P },
|
{ llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_P), LLAMA_CONSTRAINT_TYPE_TOP_P },
|
||||||
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P },
|
{ llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_MIN_P), LLAMA_CONSTRAINT_TYPE_MIN_P },
|
||||||
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE }
|
{ llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TEMPERATURE), LLAMA_CONSTRAINT_TYPE_TEMPERATURE }
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers;
|
std::vector<llama_constraint_type> samplers;
|
||||||
samplers.reserve(chars.size());
|
samplers.reserve(chars.size());
|
||||||
|
|
||||||
for (const auto & c : chars) {
|
for (const auto & c : chars) {
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct gpt_sampling_params {
|
struct gpt_sampling_params {
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling
|
||||||
|
|
||||||
int32_t n_prev = 64; // number of previous tokens to remember
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
|
@ -30,13 +30,13 @@ typedef struct gpt_sampling_params {
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
|
|
||||||
std::vector<enum llama_sampler_type> samplers = {
|
std::vector<enum llama_constraint_type> samplers = {
|
||||||
LLAMA_SAMPLER_TYPE_TOP_K,
|
LLAMA_CONSTRAINT_TYPE_TOP_K,
|
||||||
LLAMA_SAMPLER_TYPE_TFS_Z,
|
LLAMA_CONSTRAINT_TYPE_TFS_Z,
|
||||||
LLAMA_SAMPLER_TYPE_TYPICAL_P,
|
LLAMA_CONSTRAINT_TYPE_TYPICAL_P,
|
||||||
LLAMA_SAMPLER_TYPE_TOP_P,
|
LLAMA_CONSTRAINT_TYPE_TOP_P,
|
||||||
LLAMA_SAMPLER_TYPE_MIN_P,
|
LLAMA_CONSTRAINT_TYPE_MIN_P,
|
||||||
LLAMA_SAMPLER_TYPE_TEMPERATURE
|
LLAMA_CONSTRAINT_TYPE_TEMPERATURE
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
@ -48,7 +48,16 @@ typedef struct gpt_sampling_params {
|
||||||
|
|
||||||
// print the samplers into a string
|
// print the samplers into a string
|
||||||
std::string print_samplers() const;
|
std::string print_samplers() const;
|
||||||
} gpt_sampling_params;
|
};
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
struct gpt_sampler {
|
||||||
|
gpt_sampling_params params;
|
||||||
|
|
||||||
|
struct llama_constraint * grmr = nullptr;
|
||||||
|
|
||||||
|
struct llama_sampler * smpl = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
// overload of llama_sampling_init using gpt_sampling_params
|
// overload of llama_sampling_init using gpt_sampling_params
|
||||||
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
|
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
|
||||||
|
@ -72,8 +81,8 @@ llama_token llama_sampling_sample(
|
||||||
// get a string representation of the last accepted tokens
|
// get a string representation of the last accepted tokens
|
||||||
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n);
|
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n);
|
||||||
|
|
||||||
char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type);
|
char llama_sampling_type_to_chr(enum llama_constraint_type sampler_type);
|
||||||
std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type);
|
std::string llama_sampling_type_to_str(enum llama_constraint_type sampler_type);
|
||||||
|
|
||||||
std::vector<enum llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
std::vector<enum llama_constraint_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
||||||
std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars);
|
std::vector<enum llama_constraint_type> llama_sampling_types_from_chars(const std::string & chars);
|
||||||
|
|
|
@ -46,6 +46,7 @@
|
||||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
#define LLAMA_STATE_SEQ_VERSION 2
|
#define LLAMA_STATE_SEQ_VERSION 2
|
||||||
|
|
||||||
|
// TODO: remove before merge
|
||||||
#define LLAMA_MAX_SAMPLERS 16
|
#define LLAMA_MAX_SAMPLERS 16
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
@ -209,14 +210,15 @@ extern "C" {
|
||||||
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
|
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llama_sampler_type {
|
// TODO: move to common, rename to gpt_constraint_type
|
||||||
LLAMA_SAMPLER_TYPE_NONE = 0,
|
enum llama_constraint_type {
|
||||||
LLAMA_SAMPLER_TYPE_TOP_K = 1,
|
LLAMA_CONSTRAINT_TYPE_NONE = 0,
|
||||||
LLAMA_SAMPLER_TYPE_TOP_P = 2,
|
LLAMA_CONSTRAINT_TYPE_TOP_K = 1,
|
||||||
LLAMA_SAMPLER_TYPE_MIN_P = 3,
|
LLAMA_CONSTRAINT_TYPE_TOP_P = 2,
|
||||||
LLAMA_SAMPLER_TYPE_TFS_Z = 4,
|
LLAMA_CONSTRAINT_TYPE_MIN_P = 3,
|
||||||
LLAMA_SAMPLER_TYPE_TYPICAL_P = 5,
|
LLAMA_CONSTRAINT_TYPE_TFS_Z = 4,
|
||||||
LLAMA_SAMPLER_TYPE_TEMPERATURE = 6,
|
LLAMA_CONSTRAINT_TYPE_TYPICAL_P = 5,
|
||||||
|
LLAMA_CONSTRAINT_TYPE_TEMPERATURE = 6,
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef struct llama_token_data {
|
typedef struct llama_token_data {
|
||||||
|
@ -382,6 +384,7 @@ extern "C" {
|
||||||
float bias;
|
float bias;
|
||||||
} llama_logit_bias;
|
} llama_logit_bias;
|
||||||
|
|
||||||
|
// TODO: remove before merge
|
||||||
// parameters for sampling the logits
|
// parameters for sampling the logits
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
uint32_t seed; // the seed used to initialize llama_sampling_context
|
uint32_t seed; // the seed used to initialize llama_sampling_context
|
||||||
|
@ -406,7 +409,7 @@ extern "C" {
|
||||||
|
|
||||||
// samplers
|
// samplers
|
||||||
int32_t n_samplers;
|
int32_t n_samplers;
|
||||||
enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS];
|
enum llama_constraint_type samplers[LLAMA_MAX_SAMPLERS];
|
||||||
|
|
||||||
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
|
||||||
bool penalize_nl; // consider newlines as a repeatable token
|
bool penalize_nl; // consider newlines as a repeatable token
|
||||||
|
@ -414,7 +417,11 @@ extern "C" {
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
typedef struct llama_sampler_params {
|
typedef struct llama_sampler_params {
|
||||||
uint32_t seed; // the seed used to initialize the rng of the sampler
|
uint32_t seed; // the seed used to initialize the rng of the sampler
|
||||||
|
|
||||||
|
int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
|
float mirostat_tau; // target entropy
|
||||||
|
float mirostat_eta; // learning rate
|
||||||
|
|
||||||
// TODO: add type of sampler: greedy, dist, mirostat, etc.
|
// TODO: add type of sampler: greedy, dist, mirostat, etc.
|
||||||
} llama_sampler_params;
|
} llama_sampler_params;
|
||||||
|
@ -1176,6 +1183,8 @@ extern "C" {
|
||||||
typedef void * llama_constraint_context_t;
|
typedef void * llama_constraint_context_t;
|
||||||
|
|
||||||
struct llama_constraint_i {
|
struct llama_constraint_i {
|
||||||
|
// TODO: add name API
|
||||||
|
|
||||||
void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL
|
void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL
|
||||||
void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required
|
void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required
|
||||||
void (*reset) (struct llama_constraint * cnstr); // can be NULL
|
void (*reset) (struct llama_constraint * cnstr); // can be NULL
|
||||||
|
@ -1184,6 +1193,8 @@ extern "C" {
|
||||||
|
|
||||||
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
|
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
|
||||||
//void (*apply_ggml) (struct llama_constraint * cnstr, ...);
|
//void (*apply_ggml) (struct llama_constraint * cnstr, ...);
|
||||||
|
|
||||||
|
// TODO: add API to get timing stats
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_constraint {
|
struct llama_constraint {
|
||||||
|
@ -1191,14 +1202,28 @@ extern "C" {
|
||||||
llama_constraint_context_t ctx;
|
llama_constraint_context_t ctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
|
LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
|
LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
|
LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep);
|
LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep);
|
LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t);
|
LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
|
LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root);
|
LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root);
|
||||||
|
|
||||||
|
LLAMA_API struct llama_constraint * llama_constraint_init_penalties(
|
||||||
|
struct llama_model * model,
|
||||||
|
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
|
float penalty_repeat, // 1.0 = disabled
|
||||||
|
float penalty_freq, // 0.0 = disabled
|
||||||
|
float penalty_present, // 0.0 = disabled
|
||||||
|
bool penalize_nl, // consider newlines as a repeatable token
|
||||||
|
bool ignore_eos); // ignore the end-of-sequence token
|
||||||
|
|
||||||
|
LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias(
|
||||||
|
struct llama_model * model,
|
||||||
|
int32_t n_logit_bias,
|
||||||
|
const llama_logit_bias * logit_bias);
|
||||||
|
|
||||||
// do not call if used with llama_sampler_add_constraint
|
// do not call if used with llama_sampler_add_constraint
|
||||||
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
|
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
|
||||||
|
@ -1209,19 +1234,47 @@ extern "C" {
|
||||||
|
|
||||||
// samplers
|
// samplers
|
||||||
|
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params);
|
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params);
|
||||||
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
|
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
|
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
|
||||||
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl);
|
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl);
|
||||||
|
|
||||||
|
LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits);
|
||||||
|
|
||||||
|
LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);
|
||||||
|
|
||||||
|
|
||||||
// TODO: should this take ownership so the user does not need to call llama_constraint_free
|
// TODO: should this take ownership so the user does not need to call llama_constraint_free
|
||||||
// or should just make a reference to the constraint so that it can be reused in multiple llama_sampler?
|
// or should just make a reference to the constraint so that it can be reused in multiple llama_sampler?
|
||||||
//
|
//
|
||||||
// seems better to take the ownership, otherwise the copying of the sampler will be more complicated
|
// seems better to take the ownership, otherwise the copying of the sampler will be more complicated
|
||||||
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
|
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
|
||||||
|
|
||||||
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
|
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
|
||||||
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i);
|
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * candidates);
|
||||||
|
LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates);
|
||||||
|
LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
/// @details Get the number of accepted tokens so far (max of n_prev)
|
||||||
|
LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl);
|
||||||
|
|
||||||
|
/// @details Get the ith accepted token
|
||||||
|
/// @param ith [0, n_prev), ith == 0 is the last accepted token.
|
||||||
|
/// returns LLAMA_TOKEN_NULL if ith is out of bounds
|
||||||
|
LLAMA_API llama_token llama_sampler_prev(
|
||||||
|
const struct llama_sampler * smpl,
|
||||||
|
int32_t ith);
|
||||||
|
|
||||||
|
/// @details Get the last accepted token
|
||||||
|
/// Same as llama_sampler_prev(smpl, 0)
|
||||||
|
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
|
||||||
|
LLAMA_API llama_token llama_sampler_last(const struct llama_sampler * smpl);
|
||||||
|
|
||||||
|
// TODO: extend in the future
|
||||||
|
//LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t i);
|
||||||
|
//LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model split
|
// Model split
|
||||||
|
|
|
@ -676,7 +676,14 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_top_k_i;
|
result->iface = &llama_constraint_top_k_i;
|
||||||
result->ctx = new llama_constraint_context_top_k{k, min_keep};
|
result->ctx = new llama_constraint_context_top_k;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_top_k *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.k =*/ k,
|
||||||
|
/*.min_keep =*/ min_keep,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -691,7 +698,7 @@ struct llama_constraint_context_top_p {
|
||||||
static struct llama_constraint_i llama_constraint_top_p_i = {
|
static struct llama_constraint_i llama_constraint_top_p_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_top_p * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx;
|
||||||
llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep);
|
llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep);
|
||||||
},
|
},
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
|
@ -713,7 +720,14 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_top_p_i;
|
result->iface = &llama_constraint_top_p_i;
|
||||||
result->ctx = new llama_constraint_context_top_p{p, min_keep};
|
result->ctx = new llama_constraint_context_top_p;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_top_p *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.p =*/ p,
|
||||||
|
/*.min_keep =*/ min_keep,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -728,7 +742,7 @@ struct llama_constraint_context_min_p {
|
||||||
static struct llama_constraint_i llama_constraint_min_p_i = {
|
static struct llama_constraint_i llama_constraint_min_p_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_min_p * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
|
||||||
llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep);
|
llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep);
|
||||||
},
|
},
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
|
@ -750,7 +764,14 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_min_p_i;
|
result->iface = &llama_constraint_min_p_i;
|
||||||
result->ctx = new llama_constraint_context_min_p{p, min_keep};
|
result->ctx = new llama_constraint_context_min_p;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_min_p *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.p =*/ p,
|
||||||
|
/*.min_keep =*/ min_keep,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -765,7 +786,7 @@ struct llama_constraint_context_tail_free {
|
||||||
static struct llama_constraint_i llama_constraint_tail_free_i = {
|
static struct llama_constraint_i llama_constraint_tail_free_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_tail_free * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
|
||||||
llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep);
|
llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep);
|
||||||
},
|
},
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
|
@ -787,7 +808,14 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_tail_free_i;
|
result->iface = &llama_constraint_tail_free_i;
|
||||||
result->ctx = new llama_constraint_context_tail_free{z, min_keep};
|
result->ctx = new llama_constraint_context_tail_free;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_tail_free *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.z =*/ z,
|
||||||
|
/*.min_keep =*/ min_keep,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -802,7 +830,7 @@ struct llama_constraint_context_typical {
|
||||||
static struct llama_constraint_i llama_constraint_typical_i = {
|
static struct llama_constraint_i llama_constraint_typical_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_typical * ctx = (llama_constraint_context_typical *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_typical *) cnstr->ctx;
|
||||||
llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep);
|
llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep);
|
||||||
},
|
},
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
|
@ -824,7 +852,14 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_typical_i;
|
result->iface = &llama_constraint_typical_i;
|
||||||
result->ctx = new llama_constraint_context_typical{p, min_keep};
|
result->ctx = new llama_constraint_context_typical;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_typical *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.p =*/ p,
|
||||||
|
/*.min_keep =*/ min_keep,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -838,7 +873,7 @@ struct llama_constraint_context_temp {
|
||||||
static struct llama_constraint_i llama_constraint_temp_i = {
|
static struct llama_constraint_i llama_constraint_temp_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_temp * ctx = (llama_constraint_context_temp *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_temp *) cnstr->ctx;
|
||||||
llama_sampling_temp_impl(candidates, ctx->temp);
|
llama_sampling_temp_impl(candidates, ctx->temp);
|
||||||
},
|
},
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
|
@ -860,7 +895,13 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_temp_i;
|
result->iface = &llama_constraint_temp_i;
|
||||||
result->ctx = new llama_constraint_context_temp{temp};
|
result->ctx = new llama_constraint_context_temp;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_temp *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.temp =*/ temp,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -876,7 +917,7 @@ struct llama_constraint_context_temp_ext {
|
||||||
static struct llama_constraint_i llama_constraint_temp_ext_i = {
|
static struct llama_constraint_i llama_constraint_temp_ext_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_temp_ext * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
|
||||||
if (ctx->delta > 0) {
|
if (ctx->delta > 0) {
|
||||||
const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
|
const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
|
||||||
const float temp_max = ctx->temp + ctx->delta;
|
const float temp_max = ctx->temp + ctx->delta;
|
||||||
|
@ -905,7 +946,15 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
|
||||||
struct llama_constraint * result = new llama_constraint;
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
result->iface = &llama_constraint_temp_ext_i;
|
result->iface = &llama_constraint_temp_ext_i;
|
||||||
result->ctx = new llama_constraint_context_temp_ext{temp, delta, exponent};
|
result->ctx = new llama_constraint_context_temp_ext;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_temp_ext *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.temp =*/ temp,
|
||||||
|
/*.delta =*/ delta,
|
||||||
|
/*.exponent =*/ exponent,
|
||||||
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -920,15 +969,20 @@ struct llama_constraint_context_grammar {
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct llama_constraint_i llama_constraint_grammar_i = {
|
static struct llama_constraint_i llama_constraint_grammar_i = {
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
|
||||||
|
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
|
||||||
|
if (ctx->grammar) {
|
||||||
|
llama_grammar_accept_impl(*ctx->grammar, token);
|
||||||
|
}
|
||||||
|
},
|
||||||
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
|
||||||
if (ctx->grammar) {
|
if (ctx->grammar) {
|
||||||
llama_sampling_grammar_impl(candidates, *ctx->grammar);
|
llama_sampling_grammar_impl(candidates, *ctx->grammar);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
/* .reset = */ [](struct llama_constraint * cnstr) {
|
/* .reset = */ [](struct llama_constraint * cnstr) {
|
||||||
llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
|
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
|
||||||
if (ctx->grammar) {
|
if (ctx->grammar) {
|
||||||
llama_grammar_free_impl(ctx->grammar);
|
llama_grammar_free_impl(ctx->grammar);
|
||||||
ctx->grammar = nullptr;
|
ctx->grammar = nullptr;
|
||||||
|
@ -973,20 +1027,173 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
|
||||||
auto * ctx = (llama_constraint_context_grammar *) result->ctx;
|
auto * ctx = (llama_constraint_context_grammar *) result->ctx;
|
||||||
|
|
||||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||||
ctx->grammar_str = grammar_str;
|
*ctx = {
|
||||||
ctx->grammar_root = grammar_root;
|
/*.grammar_str = */ grammar_str,
|
||||||
|
/*.grammar_root = */ grammar_root,
|
||||||
ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root);
|
/*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
||||||
|
};
|
||||||
} else {
|
} else {
|
||||||
ctx->grammar_str.clear();
|
*ctx = {
|
||||||
ctx->grammar_root.clear();
|
/*.grammar_str = */ {},
|
||||||
|
/*.grammar_root = */ {},
|
||||||
ctx->grammar = nullptr;
|
/*.grammar = */ nullptr,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// penalties
|
||||||
|
|
||||||
|
struct llama_constraint_context_penalties {
|
||||||
|
const struct llama_vocab * vocab;
|
||||||
|
|
||||||
|
int32_t penalty_last_n;
|
||||||
|
float penalty_repeat;
|
||||||
|
float penalty_freq;
|
||||||
|
float penalty_present;
|
||||||
|
|
||||||
|
bool penalize_nl;
|
||||||
|
bool ignore_eos;
|
||||||
|
|
||||||
|
ring_buffer<llama_token> prev;
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct llama_constraint_i llama_constraint_penalties_i = {
|
||||||
|
/* .accept = */ [](struct llama_constraint * cnstr, llama_token token) {
|
||||||
|
auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
|
||||||
|
ctx->prev.push_back(token);
|
||||||
|
},
|
||||||
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
|
auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
|
||||||
|
|
||||||
|
GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary");
|
||||||
|
|
||||||
|
if (ctx->ignore_eos) {
|
||||||
|
candidates->data[ctx->vocab->special_eos_id].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((ctx->penalty_last_n == 0) ||
|
||||||
|
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float nl_logit = !ctx->penalize_nl ? candidates->data[ctx->vocab->linefeed_id].logit : -INFINITY;
|
||||||
|
|
||||||
|
// Create a frequency map to count occurrences of each token in last_tokens
|
||||||
|
// TODO: optimize this by maintaining the token count in the constraint context
|
||||||
|
llama_token_cnt token_count;
|
||||||
|
for (int i = 0; i < ctx->penalty_last_n; ++i) {
|
||||||
|
token_count[ctx->prev.rat(i)]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampling_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);
|
||||||
|
|
||||||
|
if (!ctx->penalize_nl) {
|
||||||
|
// restore the logit of the newline token if it was penalized
|
||||||
|
candidates->data[ctx->vocab->linefeed_id].logit = nl_logit;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/* .reset = */ [](struct llama_constraint * cnstr) {
|
||||||
|
auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx;
|
||||||
|
ctx->prev.clear();
|
||||||
|
},
|
||||||
|
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
|
||||||
|
cnstr->ctx = new llama_constraint_context_penalties;
|
||||||
|
const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr_src->ctx;
|
||||||
|
auto * ctx_dst = ( llama_constraint_context_penalties *) cnstr->ctx;
|
||||||
|
|
||||||
|
*ctx_dst = *ctx_src;
|
||||||
|
},
|
||||||
|
/* .free = */ [](struct llama_constraint * cnstr) {
|
||||||
|
if (cnstr->ctx) {
|
||||||
|
delete (llama_constraint_context_penalties *) cnstr->ctx;
|
||||||
|
}
|
||||||
|
delete cnstr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
|
||||||
|
GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL);
|
||||||
|
GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL);
|
||||||
|
|
||||||
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
|
result->iface = &llama_constraint_penalties_i;
|
||||||
|
result->ctx = new llama_constraint_context_penalties;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_penalties *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.vocab = */ &vocab,
|
||||||
|
/*.penalty_last_n = */ penalty_last_n,
|
||||||
|
/*.penalty_repeat = */ penalty_repeat,
|
||||||
|
/*.penalty_freq = */ penalty_freq,
|
||||||
|
/*.penalty_present = */ penalty_present,
|
||||||
|
/*.penalize_nl = */ penalize_nl,
|
||||||
|
/*.ignore_eos = */ ignore_eos,
|
||||||
|
/*.prev = */ {},
|
||||||
|
};
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// logit-bias
|
||||||
|
|
||||||
|
struct llama_constraint_context_logit_bias {
|
||||||
|
const struct llama_vocab * vocab;
|
||||||
|
|
||||||
|
std::vector<llama_logit_bias> logit_bias;
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct llama_constraint_i llama_constraint_logit_bias_i = {
|
||||||
|
/* .accept = */ nullptr,
|
||||||
|
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
|
||||||
|
auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx;
|
||||||
|
|
||||||
|
GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary");
|
||||||
|
|
||||||
|
for (const auto & lb : ctx->logit_bias) {
|
||||||
|
candidates->data[lb.token].logit += lb.bias;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
/* .reset = */ nullptr,
|
||||||
|
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
|
||||||
|
cnstr->ctx = new llama_constraint_context_logit_bias;
|
||||||
|
const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr_src->ctx;
|
||||||
|
auto * ctx_dst = ( llama_constraint_context_logit_bias *) cnstr->ctx;
|
||||||
|
|
||||||
|
*ctx_dst = *ctx_src;
|
||||||
|
},
|
||||||
|
/* .free = */ [](struct llama_constraint * cnstr) {
|
||||||
|
if (cnstr->ctx) {
|
||||||
|
delete (llama_constraint_context_logit_bias *) cnstr->ctx;
|
||||||
|
}
|
||||||
|
delete cnstr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_constraint * llama_constraint_init_logit_bias_impl(
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
int32_t n_logit_bias,
|
||||||
|
const llama_logit_bias * logit_bias) {
|
||||||
|
struct llama_constraint * result = new llama_constraint;
|
||||||
|
|
||||||
|
result->iface = &llama_constraint_logit_bias_i;
|
||||||
|
result->ctx = new llama_constraint_context_logit_bias;
|
||||||
|
|
||||||
|
auto * ctx = (llama_constraint_context_logit_bias *) result->ctx;
|
||||||
|
|
||||||
|
*ctx = {
|
||||||
|
/*.vocab = */ &vocab,
|
||||||
|
/*.logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
||||||
|
};
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////
|
||||||
|
|
||||||
void llama_constraint_free_impl(struct llama_constraint * cnstr) {
|
void llama_constraint_free_impl(struct llama_constraint * cnstr) {
|
||||||
if (cnstr->iface->free && cnstr) {
|
if (cnstr->iface->free && cnstr) {
|
||||||
cnstr->iface->free(cnstr);
|
cnstr->iface->free(cnstr);
|
||||||
|
@ -1012,10 +1219,11 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) {
|
||||||
|
|
||||||
// samplers
|
// samplers
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_impl(struct llama_sampler_params params) {
|
struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) {
|
||||||
auto * result = new llama_sampler;
|
auto * result = new llama_sampler;
|
||||||
|
|
||||||
result->params = params;
|
result->params = params;
|
||||||
|
result->vocab = &vocab;
|
||||||
|
|
||||||
result->rng.seed(params.seed);
|
result->rng.seed(params.seed);
|
||||||
|
|
||||||
|
@ -1075,3 +1283,22 @@ void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
|
||||||
llama_constraint_accept_impl(*cnstr, token);
|
llama_constraint_accept_impl(*cnstr, token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * candidates) {
|
||||||
|
for (auto * cnstr : smpl.constraints) {
|
||||||
|
llama_constraint_apply_impl(*cnstr, candidates);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) {
|
||||||
|
if (ith < 0 || ith >= (int) smpl.prev.size()) {
|
||||||
|
return LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return smpl.prev.rat(ith);
|
||||||
|
}
|
||||||
|
|
||||||
|
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) {
|
||||||
|
return smpl.prev.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ struct llama_grammar;
|
||||||
|
|
||||||
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
||||||
|
|
||||||
|
// TODO: remove before merge
|
||||||
struct llama_sampling {
|
struct llama_sampling {
|
||||||
llama_sampling(const struct llama_vocab & vocab);
|
llama_sampling(const struct llama_vocab & vocab);
|
||||||
~llama_sampling();
|
~llama_sampling();
|
||||||
|
@ -27,7 +28,7 @@ struct llama_sampling {
|
||||||
|
|
||||||
const struct llama_vocab & vocab;
|
const struct llama_vocab & vocab;
|
||||||
|
|
||||||
std::vector<llama_sampler_type> samplers;
|
std::vector<llama_constraint_type> samplers;
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
ring_buffer<llama_token> prev;
|
||||||
|
|
||||||
|
@ -120,7 +121,25 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t
|
||||||
struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep);
|
struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep);
|
||||||
struct llama_constraint * llama_constraint_init_temp_impl (float t);
|
struct llama_constraint * llama_constraint_init_temp_impl (float t);
|
||||||
struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent);
|
struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent);
|
||||||
struct llama_constraint * llama_constraint_init_grammar_impl (const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
|
|
||||||
|
struct llama_constraint * llama_constraint_init_grammar_impl (
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
const char * grammar_str,
|
||||||
|
const char * grammar_root);
|
||||||
|
|
||||||
|
struct llama_constraint * llama_constraint_init_penalties_impl(
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
int32_t penalty_last_n,
|
||||||
|
float penalty_repeat,
|
||||||
|
float penalty_freq,
|
||||||
|
float penalty_present,
|
||||||
|
bool penalize_nl,
|
||||||
|
bool ignore_eos);
|
||||||
|
|
||||||
|
LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias_impl(
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
int32_t n_logit_bias,
|
||||||
|
const llama_logit_bias * logit_bias);
|
||||||
|
|
||||||
void llama_constraint_free_impl(struct llama_constraint * cnstr);
|
void llama_constraint_free_impl(struct llama_constraint * cnstr);
|
||||||
|
|
||||||
|
@ -133,15 +152,22 @@ void llama_constraint_reset_impl (struct llama_constraint & cnstr);
|
||||||
struct llama_sampler {
|
struct llama_sampler {
|
||||||
llama_sampler_params params;
|
llama_sampler_params params;
|
||||||
|
|
||||||
|
const struct llama_vocab * vocab;
|
||||||
|
|
||||||
// state
|
// state
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
// TODO: move to a standalone penalty constraint?
|
float mirostat_mu;
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
ring_buffer<llama_token> prev;
|
||||||
|
|
||||||
std::vector<llama_constraint *> constraints;
|
std::vector<llama_constraint *> constraints;
|
||||||
|
|
||||||
|
std::vector<llama_token_data> cur;
|
||||||
|
|
||||||
|
llama_token_data_array cur_p;
|
||||||
|
|
||||||
// timing
|
// timing
|
||||||
|
|
||||||
mutable int64_t t_sample_us = 0;
|
mutable int64_t t_sample_us = 0;
|
||||||
|
@ -149,10 +175,15 @@ struct llama_sampler {
|
||||||
mutable int32_t n_sample = 0;
|
mutable int32_t n_sample = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_impl ( struct llama_sampler_params params);
|
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);
|
||||||
void llama_sampler_free_impl ( struct llama_sampler * smpl);
|
void llama_sampler_free_impl ( struct llama_sampler * smpl);
|
||||||
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
|
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
|
||||||
void llama_sampler_reset_impl( struct llama_sampler & smpl);
|
void llama_sampler_reset_impl( struct llama_sampler & smpl);
|
||||||
|
|
||||||
void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr);
|
void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr);
|
||||||
void llama_sampler_accept_impl (struct llama_sampler & smpl, llama_token token);
|
|
||||||
|
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
|
||||||
|
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith);
|
||||||
|
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl);
|
||||||
|
|
139
src/llama.cpp
139
src/llama.cpp
|
@ -17938,6 +17938,9 @@ struct llama_context_params llama_context_default_params() {
|
||||||
struct llama_sampler_params llama_sampler_default_params() {
|
struct llama_sampler_params llama_sampler_default_params() {
|
||||||
struct llama_sampler_params result = {
|
struct llama_sampler_params result = {
|
||||||
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
||||||
|
/*.mirostat =*/ 0,
|
||||||
|
/*.mirostat_tau =*/ 5.00f,
|
||||||
|
/*.mirostat_eta =*/ 0.10f,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -17965,7 +17968,7 @@ struct llama_sampling_params llama_sampling_default_params() {
|
||||||
/*.mirostat_tau =*/ 5.00f,
|
/*.mirostat_tau =*/ 5.00f,
|
||||||
/*.mirostat_eta =*/ 0.10f,
|
/*.mirostat_eta =*/ 0.10f,
|
||||||
/*.n_samplers =*/ 3,
|
/*.n_samplers =*/ 3,
|
||||||
/*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, },
|
/*.samplers =*/ { LLAMA_CONSTRAINT_TYPE_TEMPERATURE, LLAMA_CONSTRAINT_TYPE_TOP_K, LLAMA_CONSTRAINT_TYPE_TOP_P, },
|
||||||
/*.penalize_nl =*/ false,
|
/*.penalize_nl =*/ false,
|
||||||
/*.ignore_eos =*/ false,
|
/*.ignore_eos =*/ false,
|
||||||
};
|
};
|
||||||
|
@ -20916,12 +20919,12 @@ llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data
|
||||||
} else {
|
} else {
|
||||||
for (const auto & sampler : smpl->samplers) {
|
for (const auto & sampler : smpl->samplers) {
|
||||||
switch (sampler) {
|
switch (sampler) {
|
||||||
case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
|
case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
|
||||||
case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
|
case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
|
||||||
case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
|
case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
|
||||||
case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
|
case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
|
||||||
case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
|
case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
|
||||||
case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
|
case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
|
||||||
default : break;
|
default : break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21007,6 +21010,24 @@ struct llama_constraint * llama_constraint_init_grammar(struct llama_model * mod
|
||||||
return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root);
|
return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_constraint * llama_constraint_init_penalties(
|
||||||
|
struct llama_model * model,
|
||||||
|
int32_t penalty_last_n,
|
||||||
|
float penalty_repeat,
|
||||||
|
float penalty_freq,
|
||||||
|
float penalty_present,
|
||||||
|
bool penalize_nl,
|
||||||
|
bool ignore_eos) {
|
||||||
|
return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos);
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias(
|
||||||
|
struct llama_model * model,
|
||||||
|
int32_t n_logit_bias,
|
||||||
|
const llama_logit_bias * logit_bias) {
|
||||||
|
return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_constraint_free(struct llama_constraint * cnstr) {
|
void llama_constraint_free(struct llama_constraint * cnstr) {
|
||||||
if (cnstr == nullptr) {
|
if (cnstr == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
@ -21027,8 +21048,8 @@ void llama_constraint_reset(struct llama_constraint * cnstr) {
|
||||||
llama_constraint_reset_impl(*cnstr);
|
llama_constraint_reset_impl(*cnstr);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init(struct llama_sampler_params params) {
|
struct llama_sampler * llama_sampler_init(const struct llama_model * model, struct llama_sampler_params params) {
|
||||||
return llama_sampler_init_impl(params);
|
return llama_sampler_init_impl(model->vocab, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_free(struct llama_sampler * smpl) {
|
void llama_sampler_free(struct llama_sampler * smpl) {
|
||||||
|
@ -21047,6 +21068,22 @@ void llama_sampler_reset(struct llama_sampler * smpl) {
|
||||||
llama_sampler_reset_impl(*smpl);
|
llama_sampler_reset_impl(*smpl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) {
|
||||||
|
const int n_vocab = smpl->vocab->n_vocab;
|
||||||
|
|
||||||
|
smpl->cur.resize(n_vocab);
|
||||||
|
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||||
|
}
|
||||||
|
|
||||||
|
smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false };
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl) {
|
||||||
|
return &smpl->cur_p;
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
|
void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
|
||||||
llama_sampler_add_constraint_impl(*smpl, cnstr);
|
llama_sampler_add_constraint_impl(*smpl, cnstr);
|
||||||
}
|
}
|
||||||
|
@ -21055,10 +21092,90 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
||||||
llama_sampler_accept_impl(*smpl, token);
|
llama_sampler_accept_impl(*smpl, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) {
|
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) {
|
||||||
GGML_ABORT("not implemented");
|
time_meas tm(smpl->t_sample_us);
|
||||||
|
|
||||||
|
llama_sampler_apply_impl(*smpl, candidates);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates) {
|
||||||
|
time_meas tm(smpl->t_sample_us);
|
||||||
|
|
||||||
|
if (candidates == nullptr) {
|
||||||
|
candidates = &smpl->cur_p;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto type = smpl->params.mirostat;
|
||||||
|
|
||||||
|
llama_token res;
|
||||||
|
|
||||||
|
if (type == 1) {
|
||||||
|
res = llama_sampling_sample_mirostat_impl(candidates,
|
||||||
|
smpl->rng,
|
||||||
|
smpl->params.mirostat_tau,
|
||||||
|
smpl->params.mirostat_eta,
|
||||||
|
100,
|
||||||
|
smpl->vocab->n_vocab,
|
||||||
|
smpl->mirostat_mu);
|
||||||
|
} else if (type == 2) {
|
||||||
|
res = llama_sampling_sample_mirostat_v2_impl(candidates,
|
||||||
|
smpl->rng,
|
||||||
|
smpl->params.mirostat_tau,
|
||||||
|
smpl->params.mirostat_eta,
|
||||||
|
smpl->mirostat_mu);
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("invalid mirostat type: %d", type);
|
||||||
|
}
|
||||||
|
|
||||||
|
smpl->n_sample++;
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates) {
|
||||||
|
time_meas tm(smpl->t_sample_us);
|
||||||
|
|
||||||
|
if (candidates == nullptr) {
|
||||||
|
candidates = &smpl->cur_p;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = llama_sampling_sample_greedy_impl(candidates);
|
||||||
|
|
||||||
|
smpl->n_sample++;
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * candidates) {
|
||||||
|
time_meas tm(smpl->t_sample_us);
|
||||||
|
|
||||||
|
if (candidates == nullptr) {
|
||||||
|
candidates = &smpl->cur_p;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
|
||||||
|
|
||||||
|
smpl->n_sample++;
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
int llama_sampler_n_prev(const struct llama_sampler * smpl) {
|
||||||
|
return llama_sampler_n_prev_impl(*smpl);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith) {
|
||||||
|
return llama_sampler_prev_impl(*smpl, ith);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampler_last(const struct llama_sampler * smpl) {
|
||||||
|
return llama_sampler_prev_impl(*smpl, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
//llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) {
|
||||||
|
// GGML_ABORT("not implemented");
|
||||||
|
//}
|
||||||
|
|
||||||
//
|
//
|
||||||
// model split
|
// model split
|
||||||
//
|
//
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue