From 784a64404093afb9dae64a8f44803141d9789087 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 17:13:15 +0300 Subject: [PATCH] sampler : API to iterate constraints ggml-ci --- common/sampling.cpp | 39 ++++++++++------------- common/sampling.h | 8 ++--- examples/batched.swift/Sources/main.swift | 6 ++-- examples/batched/batched.cpp | 6 ++-- examples/infill/infill.cpp | 2 +- examples/main/main.cpp | 17 +++++----- include/llama.h | 9 ++++-- src/llama-sampling.cpp | 14 +++++++- src/llama-sampling.h | 4 ++- src/llama.cpp | 12 +++++-- 10 files changed, 69 insertions(+), 48 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index b528d4929..718001844 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -12,7 +12,7 @@ struct gpt_sampler { struct llama_sampler * smpl; }; -std::string gpt_sampler_params::print_all() const { +std::string gpt_sampler_params::print() const { char result[1024]; snprintf(result, sizeof(result), @@ -26,17 +26,12 @@ std::string gpt_sampler_params::print_all() const { return std::string(result); } -std::string gpt_sampler_params::print_constraints() const { - std::string result = "CFG -> Penalties "; - if (mirostat == 0) { - for (const auto & cnstr : constraints) { - const auto name = gpt_constraint_type_to_str(cnstr); - if (!name.empty()) { - result += "-> " + name + " "; - } - } - } else { - result += "-> mirostat "; +std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { + std::string result = "\tlogits"; + + for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) { + const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i); + result += " -> " + std::string(cnstr->iface->name(cnstr)) + " "; } return result; @@ -70,33 +65,33 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st for (const auto & cnstr : params.constraints) { switch (cnstr) { case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; default: GGML_ASSERT(false && "unknown constraint type"); } } } else if (params.mirostat == 1) { - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); } else if (params.mirostat == 2) { - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } diff --git a/common/sampling.h b/common/sampling.h index bab264937..8ec745999 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -54,10 +54,7 @@ struct gpt_sampler_params { std::vector logit_bias; // logit biases to apply // print the parameters into a string - std::string print_all() const; - - // print the constraints into a string - std::string print_constraints() const; + std::string print() const; }; // gpt_sampler extends llama_sampler with additional functionality: @@ -100,6 +97,9 @@ llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_da // helpers +// print the constraints into a string +std::string gpt_sampler_print(const struct gpt_sampler * gsmpl); + // get a string representation of the last accepted tokens std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 380040e57..6b9f3e0d5 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -61,9 +61,9 @@ defer { llama_sampler_free(smpl) } -llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1)); -llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1)); -llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4)); +llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1)); +llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); +llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); let n_ctx = llama_n_ctx(context) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 0f35f6cd5..cbab4b66b 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_init(model, sparams); - llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); - llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); - llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp)); + llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); + llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); + llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 9f9f81a7f..3895b586e 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -301,7 +301,7 @@ int main(int argc, char ** argv) { LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); } } - LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str()); + LOG_TEE("sampling: \n%s\n", sparams.print().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1b706efbc..85dea9782 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -457,8 +457,15 @@ int main(int argc, char ** argv) { } } } - LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str()); - LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str()); + + smpl = gpt_sampler_init(model, sparams); + if (!smpl) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } + + LOG_TEE("sampling params: \n%s\n", sparams.print().c_str()); + LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state @@ -525,12 +532,6 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - smpl = gpt_sampler_init(model, sparams); - if (!smpl) { - fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); - exit(1); - } - if (llama_model_has_encoder(model)) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); diff --git a/include/llama.h b/include/llama.h index 8c2a2aff9..813d854ef 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1012,7 +1012,7 @@ extern "C" { // The llama_sampler object contains the entire sampling information: // // - RNG state (seed and generator) - // - Custom set of constraints (see llama_sampler_add_constraint) + // - Custom set of constraints (see llama_sampler_constraint_add) // - Sampling method (greedy, dist) // - Previous tokens // @@ -1083,7 +1083,7 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); - // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint) + // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add) LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); @@ -1102,7 +1102,10 @@ extern "C" { LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); // important: takes ownership of the constraint object and will free it in llama_sampler_free - LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + LLAMA_API void llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr); + LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl); + LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 385f7bec1..81cc357db 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1215,10 +1215,22 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) { // TODO: should we reset the timings? } -void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { +void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { smpl.constraints.push_back(cnstr); } +int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) { + return smpl.constraints.size(); +} + +struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) { + if (ith < 0 || ith >= (int) smpl.constraints.size()) { + return nullptr; + } + + return smpl.constraints[ith]; +} + void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { smpl.prev.push_back(token); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index aad9f311a..bf5f596f7 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -109,7 +109,9 @@ void llama_sampler_free_impl ( struct llama_sampler * smp struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); void llama_sampler_reset_impl( struct llama_sampler & smpl); -void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); +void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); +int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); +struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p); diff --git a/src/llama.cpp b/src/llama.cpp index 8f6503152..6426073eb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20729,8 +20729,16 @@ llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smp return &smpl->cur_p; } -void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) { - llama_sampler_add_constraint_impl(*smpl, cnstr); +void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) { + llama_sampler_constraint_add_impl(*smpl, cnstr); +} + +int llama_sampler_n_constraints (const struct llama_sampler * smpl) { + return llama_sampler_n_constraints_impl(*smpl); +} + +struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) { + return llama_sampler_constraint_get_impl(*smpl, i); } void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {