From 86b07ccbb3d93779d47f067664269a429d29d263 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 12:09:08 +0300 Subject: [PATCH] llama : sketching new sampling API --- include/llama.h | 39 +++++++++++++++++++++++++++++++++++++-- src/llama.cpp | 8 ++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/include/llama.h b/include/llama.h index 099c9e746..244e07705 100644 --- a/include/llama.h +++ b/include/llama.h @@ -61,6 +61,8 @@ extern "C" { // struct llama_vocab; // TODO: add in the future struct llama_model; struct llama_context; + struct llama_sampler; + struct llama_constraint; struct llama_sampling; typedef int32_t llama_pos; @@ -412,6 +414,11 @@ extern "C" { bool ignore_eos; // ignore the end-of-sequence token } llama_sampling_params; + typedef struct llama_sampler_params { + bool dummy; + // TODO: add type of sampler: greedy, dist, mirostat, etc. + } llama_sampler_params; + // performance timing information struct llama_timings { double t_start_ms; @@ -440,8 +447,10 @@ extern "C" { struct llama_lora_adapter; // Helpers for getting default parameters + // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_sampler_params llama_sampler_default_params(void); LLAMA_API struct llama_sampling_params llama_sampling_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -465,7 +474,7 @@ extern "C" { LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_model_params params); + struct llama_model_params params); LLAMA_API void llama_free_model(struct llama_model * model); @@ -1031,7 +1040,7 @@ extern "C" { int32_t length); // - // Sampling functions + // Sampling API // // TODO: llama_model should become llama_vocab @@ -1156,6 +1165,32 @@ extern "C" { /// returns LLAMA_TOKEN_NULL if there are no accepted tokens LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); + // + // Sampling v2 API + // + + // samplers + + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + + LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + + // constraints + + LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); + // ... + LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); + + LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); + LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + // // Model split // diff --git a/src/llama.cpp b/src/llama.cpp index 258a56842..d0ca96aca 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17935,6 +17935,14 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_sampler_params llama_sampler_default_params() { + struct llama_sampler_params result = { + /*.dummy =*/ false, + }; + + return result; +} + struct llama_sampling_params llama_sampling_default_params() { struct llama_sampling_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED,