llama : sketching new sampling API
This commit is contained in:
parent
ab545c8380
commit
86b07ccbb3
2 changed files with 45 additions and 2 deletions
|
@ -61,6 +61,8 @@ extern "C" {
|
||||||
// struct llama_vocab; // TODO: add in the future
|
// struct llama_vocab; // TODO: add in the future
|
||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_context;
|
struct llama_context;
|
||||||
|
struct llama_sampler;
|
||||||
|
struct llama_constraint;
|
||||||
struct llama_sampling;
|
struct llama_sampling;
|
||||||
|
|
||||||
typedef int32_t llama_pos;
|
typedef int32_t llama_pos;
|
||||||
|
@ -412,6 +414,11 @@ extern "C" {
|
||||||
bool ignore_eos; // ignore the end-of-sequence token
|
bool ignore_eos; // ignore the end-of-sequence token
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
|
typedef struct llama_sampler_params {
|
||||||
|
bool dummy;
|
||||||
|
// TODO: add type of sampler: greedy, dist, mirostat, etc.
|
||||||
|
} llama_sampler_params;
|
||||||
|
|
||||||
// performance timing information
|
// performance timing information
|
||||||
struct llama_timings {
|
struct llama_timings {
|
||||||
double t_start_ms;
|
double t_start_ms;
|
||||||
|
@ -440,8 +447,10 @@ extern "C" {
|
||||||
struct llama_lora_adapter;
|
struct llama_lora_adapter;
|
||||||
|
|
||||||
// Helpers for getting default parameters
|
// Helpers for getting default parameters
|
||||||
|
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
|
||||||
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||||
|
LLAMA_API struct llama_sampler_params llama_sampler_default_params(void);
|
||||||
LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
|
LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
|
||||||
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
|
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
|
||||||
|
|
||||||
|
@ -1031,7 +1040,7 @@ extern "C" {
|
||||||
int32_t length);
|
int32_t length);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Sampling functions
|
// Sampling API
|
||||||
//
|
//
|
||||||
|
|
||||||
// TODO: llama_model should become llama_vocab
|
// TODO: llama_model should become llama_vocab
|
||||||
|
@ -1156,6 +1165,32 @@ extern "C" {
|
||||||
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
|
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
|
||||||
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
|
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Sampling v2 API
|
||||||
|
//
|
||||||
|
|
||||||
|
// samplers
|
||||||
|
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params);
|
||||||
|
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
|
||||||
|
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl);
|
||||||
|
|
||||||
|
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
|
||||||
|
|
||||||
|
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
|
||||||
|
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i);
|
||||||
|
|
||||||
|
// constraints
|
||||||
|
|
||||||
|
LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep);
|
||||||
|
LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep);
|
||||||
|
// ...
|
||||||
|
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
|
||||||
|
|
||||||
|
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
|
||||||
|
LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model split
|
// Model split
|
||||||
//
|
//
|
||||||
|
|
|
@ -17935,6 +17935,14 @@ struct llama_context_params llama_context_default_params() {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_sampler_params llama_sampler_default_params() {
|
||||||
|
struct llama_sampler_params result = {
|
||||||
|
/*.dummy =*/ false,
|
||||||
|
};
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_sampling_params llama_sampling_default_params() {
|
struct llama_sampling_params llama_sampling_default_params() {
|
||||||
struct llama_sampling_params result = {
|
struct llama_sampling_params result = {
|
||||||
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue