llama : sketching new sampling API

This commit is contained in:
Georgi Gerganov 2024-09-03 12:09:08 +03:00
parent ab545c8380
commit 86b07ccbb3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 45 additions and 2 deletions

View file

@ -61,6 +61,8 @@ extern "C" {
// struct llama_vocab; // TODO: add in the future
struct llama_model;
struct llama_context;
struct llama_sampler;
struct llama_constraint;
struct llama_sampling;
typedef int32_t llama_pos;
@ -412,6 +414,11 @@ extern "C" {
bool ignore_eos; // ignore the end-of-sequence token
} llama_sampling_params;
typedef struct llama_sampler_params {
bool dummy;
// TODO: add type of sampler: greedy, dist, mirostat, etc.
} llama_sampler_params;
// performance timing information
struct llama_timings {
double t_start_ms;
@ -440,8 +447,10 @@ extern "C" {
struct llama_lora_adapter;
// Helpers for getting default parameters
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_sampler_params llama_sampler_default_params(void);
LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
@ -465,7 +474,7 @@ extern "C" {
LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params);
struct llama_model_params params);
LLAMA_API void llama_free_model(struct llama_model * model);
@ -1031,7 +1040,7 @@ extern "C" {
int32_t length);
//
// Sampling functions
// Sampling API
//
// TODO: llama_model should become llama_vocab
@ -1156,6 +1165,32 @@ extern "C" {
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
//
// Sampling v2 API
//
// samplers
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params);
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl);
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i);
// constraints
LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep);
// ...
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates);
//
// Model split
//

View file

@ -17935,6 +17935,14 @@ struct llama_context_params llama_context_default_params() {
return result;
}
struct llama_sampler_params llama_sampler_default_params() {
struct llama_sampler_params result = {
/*.dummy =*/ false,
};
return result;
}
struct llama_sampling_params llama_sampling_default_params() {
struct llama_sampling_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,