diff --git a/common/common.cpp b/common/common.cpp index 23d171a4d..f7095c7f3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -841,15 +841,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.defrag_thold = std::stof(argv[i]); return true; } - if (arg == "--samplers") { + if (arg == "--samplers" || arg == "--constraints") { CHECK_ARG - const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers = llama_sampling_types_from_names(sampler_names, true); + const auto constraint_names = string_split(argv[i], ';'); + sparams.constraints = gpt_constraint_types_from_names(constraint_names, true); return true; } if (arg == "--sampling-seq") { CHECK_ARG - sparams.samplers = llama_sampling_types_from_chars(argv[i]); + sparams.constraints = gpt_constraint_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -1706,13 +1706,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const auto & sparams = params.sparams; - std::string sampler_type_chars; - std::string sampler_type_names; - for (const auto & sampler : sparams.samplers) { - sampler_type_chars += llama_sampling_type_to_chr(sampler); - sampler_type_names += llama_sampling_type_to_str(sampler) + ";"; + std::string constraint_type_chars; + std::string constraint_type_names; + for (const auto & constraint : sparams.constraints) { + constraint_type_chars += gpt_constraint_type_to_chr(constraint); + constraint_type_names += gpt_constraint_type_to_str(constraint) + ";"; } - sampler_type_names.pop_back(); + constraint_type_names.pop_back(); struct option_info { LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) @@ -1826,9 +1826,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "sampling" }); options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" - "(default: %s)", sampler_type_names.c_str() }); + "(default: %s)", constraint_type_names.c_str() }); options.push_back({ "*", " --sampling-seq SEQUENCE", - "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); + "simplified sequence for samplers that will be used (default: %s)", constraint_type_chars.c_str() }); options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? 
"true" : "false" }); options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp }); diff --git a/common/common.h b/common/common.h index 1c4eae34a..3a6c8e0b5 100644 --- a/common/common.h +++ b/common/common.h @@ -118,7 +118,7 @@ struct gpt_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - struct gpt_sampling_params sparams; + struct gpt_sampler_params sparams; std::string model = ""; // model path std::string model_draft = ""; // draft model for speculative decoding diff --git a/common/sampling.cpp b/common/sampling.cpp index a98117cf8..a5e76dfd4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,7 +2,7 @@ #include "common.h" -std::string gpt_sampling_params::print_all() const { +std::string gpt_sampler_params::print_all() const { char result[1024]; snprintf(result, sizeof(result), @@ -16,11 +16,11 @@ std::string gpt_sampling_params::print_all() const { return std::string(result); } -std::string gpt_sampling_params::print_samplers() const { +std::string gpt_sampler_params::print_constraints() const { std::string result = "CFG -> Penalties "; if (mirostat == 0) { - for (const auto & sampler : samplers) { - const auto name = llama_sampling_type_to_str(sampler); + for (const auto & cnstr : constraints) { + const auto name = gpt_constraint_type_to_str(cnstr); if (!name.empty()) { result += "-> " + name + " "; } @@ -32,66 +32,159 @@ std::string gpt_sampling_params::print_samplers() const { return result; } -struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) { - llama_sampling_params lparams = llama_sampling_default_params(); +struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { + gpt_sampler * result = new gpt_sampler(); - lparams.seed = params.seed; - lparams.n_prev = params.n_prev; - lparams.n_probs = params.n_probs; - lparams.min_keep = params.min_keep; - lparams.top_k = params.top_k; - lparams.top_p = params.top_p; - lparams.min_p = params.min_p; - lparams.tfs_z = params.tfs_z; - lparams.typ_p = params.typ_p; - lparams.temp = params.temp; - lparams.dynatemp_range = params.dynatemp_range; - lparams.dynatemp_exponent = params.dynatemp_exponent; - lparams.penalty_last_n = params.penalty_last_n; - lparams.penalty_repeat = params.penalty_repeat; - lparams.penalty_freq = params.penalty_freq; - lparams.penalty_present = params.penalty_present; - lparams.mirostat = params.mirostat; - lparams.mirostat_tau = params.mirostat_tau; - lparams.mirostat_eta = params.mirostat_eta; - lparams.penalize_nl = params.penalize_nl; - lparams.ignore_eos = params.ignore_eos; + llama_sampler_params lparams = llama_sampler_default_params(); - lparams.n_samplers = params.samplers.size(); - for (int i = 0; i < lparams.n_samplers; i++) { - lparams.samplers[i] = params.samplers[i]; + lparams.seed = params.seed; + lparams.mirostat = params.mirostat; + lparams.mirostat_tau = params.mirostat_tau; + lparams.mirostat_eta = params.mirostat_eta; + + result->smpl = llama_sampler_init(model, lparams); + + llama_sampler_add_constraint(result->smpl, llama_constraint_init_logit_bias( + model, + params.logit_bias.size(), + params.logit_bias.data())); + + llama_sampler_add_constraint(result->smpl, llama_constraint_init_penalties( + model, + params.penalty_last_n, + params.penalty_repeat, + 
params.penalty_freq, + params.penalty_present, + params.penalize_nl, + params.ignore_eos)); + + for (const auto & cnstr : params.constraints) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TOP_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_MIN_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TFS_Z: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + default: + GGML_ASSERT(false && "unknown constraint type"); + } } - struct llama_sampling * result = llama_sampling_init(model, lparams); - - llama_sampling_set_grammar (result, params.grammar.c_str(), "root"); - llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data()); + result->grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"); return result; } -void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst) { - if (dst) { - llama_sampling_free(dst); - } +void gpt_sampler_free(struct gpt_sampler * gsmpl) { + if (gsmpl) { + llama_constraint_free(gsmpl->grmr); + llama_sampler_free(gsmpl->smpl); - dst = llama_sampling_cp(src); + delete gsmpl; + } } -llama_token llama_sampling_sample( - struct llama_sampling * smpl, +struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) { + gpt_sampler * result = new gpt_sampler(); + + result->grmr = llama_constraint_cp(gsmpl->grmr); + result->smpl = llama_sampler_cp(gsmpl->smpl); + + return result; +} + +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) { + if (apply_grammar) { + llama_constraint_accept(gsmpl->grmr, token); + } + + llama_sampler_accept(gsmpl->smpl, token); +} + +void gpt_sampler_reset (struct gpt_sampler * gsmpl) { + llama_constraint_reset(gsmpl->grmr); + + llama_sampler_reset(gsmpl->smpl); +} + +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { + return llama_sampler_last(gsmpl->smpl); +} + +static llama_token gpt_sampler_sample( + struct llama_sampler * smpl, + struct llama_token_data_array * cur_p, + float temp, + int mirostat, + int n_probs) { + GGML_ASSERT(cur_p != nullptr && "candidates array must be provided"); + + llama_token res = 0; + + if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) { + // greedy sampling, with probs + res = llama_sampler_sample_greedy(smpl, cur_p, true); + } else if (temp == 0.0f) { + // greedy sampling, no probs + res = llama_sampler_sample_greedy(smpl, cur_p, false); + } else { + llama_sampler_apply(smpl, cur_p); + + if (mirostat != 0) { + res = llama_sampler_sample_mirostat(smpl, cur_p); + } else { + res = llama_sampler_sample_dist(smpl, cur_p); + + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); + + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. 
+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); + // } + //} + + //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); + } + } + + return res; +} + +llama_token gpt_sampler_sample( + struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { - llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + const auto & params = gsmpl->params; + + auto & grmr = gsmpl->grmr; + auto & smpl = gsmpl->smpl; + + llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); // first, sample the token without any grammar constraints - const llama_token id = llama_sampling_sample(smpl, nullptr); + const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); // create an array with a single token data element for the sampled id llama_token_data single_token_data = { id, 1.0f, 0.0f }; llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - llama_sampling_grammar(smpl, &single_token_data_array); + llama_constraint_apply(grmr, &single_token_data_array); // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; @@ -100,15 +193,18 @@ llama_token llama_sampling_sample( } // if the token is not valid, sample again, after applying the grammar constraints - llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + auto * cur_p = llama_sampler_get_candidates(smpl); - llama_sampling_grammar(smpl, nullptr); + llama_constraint_apply(grmr, cur_p); - return llama_sampling_sample(smpl, nullptr); + return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); } -std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) { - n = std::min(n, llama_sampling_n_prev(smpl)); +std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { + auto & smpl = gsmpl->smpl; + + n = std::min(n, llama_sampler_n_prev(smpl)); if (n <= 0) { return ""; @@ -118,7 +214,7 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab for (int i = n - 1; i >= 0; i--) { - const llama_token id = llama_sampling_prev(smpl, i); + const llama_token id = llama_sampler_prev(smpl, i); GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); @@ -128,95 +224,95 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m return result; } -char llama_sampling_type_to_chr(llama_constraint_type sampler) { - switch (sampler) { - case LLAMA_CONSTRAINT_TYPE_TOP_K: return 'k'; - case LLAMA_CONSTRAINT_TYPE_TFS_Z: return 'f'; - case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; - case LLAMA_CONSTRAINT_TYPE_TOP_P: return 'p'; - case LLAMA_CONSTRAINT_TYPE_MIN_P: return 'm'; - case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return 't'; +char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: return 'k'; + case GPT_CONSTRAINT_TYPE_TFS_Z: return 'f'; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; + case GPT_CONSTRAINT_TYPE_TOP_P: return 'p'; + case GPT_CONSTRAINT_TYPE_MIN_P: return 'm'; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: return 't'; default : return '?'; } } -std::string 
llama_sampling_type_to_str(llama_constraint_type sampler) { - switch (sampler) { - case LLAMA_CONSTRAINT_TYPE_TOP_K: return "top_k"; - case LLAMA_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; - case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; - case LLAMA_CONSTRAINT_TYPE_TOP_P: return "top_p"; - case LLAMA_CONSTRAINT_TYPE_MIN_P: return "min_p"; - case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; +std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: return "top_k"; + case GPT_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; + case GPT_CONSTRAINT_TYPE_TOP_P: return "top_p"; + case GPT_CONSTRAINT_TYPE_MIN_P: return "min_p"; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } -std::vector<llama_constraint_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) { - std::unordered_map<std::string, llama_constraint_type> sampler_canonical_name_map { - { "top_k", LLAMA_CONSTRAINT_TYPE_TOP_K }, - { "top_p", LLAMA_CONSTRAINT_TYPE_TOP_P }, - { "typ_p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "min_p", LLAMA_CONSTRAINT_TYPE_MIN_P }, - { "tfs_z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { "temperature", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, +std::vector<gpt_constraint_type> gpt_constraint_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) { + std::unordered_map<std::string, gpt_constraint_type> constraint_canonical_name_map { + { "top_k", GPT_CONSTRAINT_TYPE_TOP_K }, + { "top_p", GPT_CONSTRAINT_TYPE_TOP_P }, + { "typ_p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "min_p", GPT_CONSTRAINT_TYPE_MIN_P }, + { "tfs_z", GPT_CONSTRAINT_TYPE_TFS_Z }, + { "temperature", GPT_CONSTRAINT_TYPE_TEMPERATURE }, }; - // since samplers names are written multiple ways + // since constraint names are written multiple ways // make it ready for both system names and input names - std::unordered_map<std::string, llama_constraint_type> sampler_alt_name_map { - { "top-k", LLAMA_CONSTRAINT_TYPE_TOP_K }, - { "top-p", LLAMA_CONSTRAINT_TYPE_TOP_P }, - { "nucleus", LLAMA_CONSTRAINT_TYPE_TOP_P }, - { "typical-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "typical", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "min-p", LLAMA_CONSTRAINT_TYPE_MIN_P }, - { "tfs-z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { "tfs", LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { "temp", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, + std::unordered_map<std::string, gpt_constraint_type> constraint_alt_name_map { + { "top-k", GPT_CONSTRAINT_TYPE_TOP_K }, + { "top-p", GPT_CONSTRAINT_TYPE_TOP_P }, + { "nucleus", GPT_CONSTRAINT_TYPE_TOP_P }, + { "typical-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "typical", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "min-p", GPT_CONSTRAINT_TYPE_MIN_P }, + { "tfs-z", GPT_CONSTRAINT_TYPE_TFS_Z }, + { "tfs", GPT_CONSTRAINT_TYPE_TFS_Z }, + { "temp", GPT_CONSTRAINT_TYPE_TEMPERATURE }, }; - std::vector<llama_constraint_type> samplers; - samplers.reserve(names.size()); + std::vector<gpt_constraint_type> constraints; + constraints.reserve(names.size()); for (const auto & name : names) { - auto sampler = sampler_canonical_name_map.find(name); - if (sampler != sampler_canonical_name_map.end()) { - samplers.push_back(sampler->second); + auto constraint = constraint_canonical_name_map.find(name); + if (constraint != constraint_canonical_name_map.end()) { + constraints.push_back(constraint->second); } else { if (allow_alt_names) { - sampler = sampler_alt_name_map.find(name); - if (sampler != sampler_alt_name_map.end()) { -
samplers.push_back(sampler->second); + constraint = constraint_alt_name_map.find(name); + if (constraint != constraint_alt_name_map.end()) { + constraints.push_back(constraint->second); } } } } - return samplers; + return constraints; } -std::vector<llama_constraint_type> llama_sampling_types_from_chars(const std::string & chars) { - std::unordered_map<char, llama_constraint_type> sampler_name_map { - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_K), LLAMA_CONSTRAINT_TYPE_TOP_K }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TFS_Z), LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TYPICAL_P), LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_P), LLAMA_CONSTRAINT_TYPE_TOP_P }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_MIN_P), LLAMA_CONSTRAINT_TYPE_MIN_P }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TEMPERATURE), LLAMA_CONSTRAINT_TYPE_TEMPERATURE } +std::vector<gpt_constraint_type> gpt_constraint_types_from_chars(const std::string & chars) { + std::unordered_map<char, gpt_constraint_type> constraint_name_map { + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_K), GPT_CONSTRAINT_TYPE_TOP_K }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TFS_Z), GPT_CONSTRAINT_TYPE_TFS_Z }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TYPICAL_P), GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_P), GPT_CONSTRAINT_TYPE_TOP_P }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_MIN_P), GPT_CONSTRAINT_TYPE_MIN_P }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TEMPERATURE), GPT_CONSTRAINT_TYPE_TEMPERATURE } }; - std::vector<llama_constraint_type> samplers; - samplers.reserve(chars.size()); + std::vector<gpt_constraint_type> constraints; + constraints.reserve(chars.size()); for (const auto & c : chars) { - const auto sampler = sampler_name_map.find(c); - if (sampler != sampler_name_map.end()) { - samplers.push_back(sampler->second); + const auto constraint = constraint_name_map.find(c); + if (constraint != constraint_name_map.end()) { + constraints.push_back(constraint->second); } } - return samplers; + return constraints; } diff --git a/common/sampling.h b/common/sampling.h index 365b7639a..4efa4a17c 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -5,13 +5,23 @@ #include <string> #include <vector> +enum gpt_constraint_type { + GPT_CONSTRAINT_TYPE_NONE = 0, + GPT_CONSTRAINT_TYPE_TOP_K = 1, + GPT_CONSTRAINT_TYPE_TOP_P = 2, + GPT_CONSTRAINT_TYPE_MIN_P = 3, + GPT_CONSTRAINT_TYPE_TFS_Z = 4, + GPT_CONSTRAINT_TYPE_TYPICAL_P = 5, + GPT_CONSTRAINT_TYPE_TEMPERATURE = 6, +}; + // sampling parameters -struct gpt_sampling_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling +struct gpt_sampler_params { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
- int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t min_keep = 0; // 0 = disabled, otherwise constraints should return at least min_keep tokens int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled @@ -30,13 +40,13 @@ struct gpt_sampling_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector<llama_constraint_type> samplers = { - LLAMA_CONSTRAINT_TYPE_TOP_K, - LLAMA_CONSTRAINT_TYPE_TFS_Z, - LLAMA_CONSTRAINT_TYPE_TYPICAL_P, - LLAMA_CONSTRAINT_TYPE_TOP_P, - LLAMA_CONSTRAINT_TYPE_MIN_P, - LLAMA_CONSTRAINT_TYPE_TEMPERATURE + std::vector<gpt_constraint_type> constraints = { + GPT_CONSTRAINT_TYPE_TOP_K, + GPT_CONSTRAINT_TYPE_TFS_Z, + GPT_CONSTRAINT_TYPE_TYPICAL_P, + GPT_CONSTRAINT_TYPE_TOP_P, + GPT_CONSTRAINT_TYPE_MIN_P, + GPT_CONSTRAINT_TYPE_TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling @@ -46,23 +56,30 @@ struct gpt_sampling_params { // print the parameters into a string std::string print_all() const; - // print the samplers into a string - std::string print_samplers() const; + // print the constraints into a string + std::string print_constraints() const; }; -// TODO: implement struct gpt_sampler { - gpt_sampling_params params; + gpt_sampler_params params; struct llama_constraint * grmr = nullptr; struct llama_sampler * smpl = nullptr; }; -// overload of llama_sampling_init using gpt_sampling_params -struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); +// llama_sampler API overload -void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst); +struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); + +void gpt_sampler_free(struct gpt_sampler * gsmpl); + +struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); + +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); +void gpt_sampler_reset (struct gpt_sampler * gsmpl); + +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); // common sampling implementation: // @@ -71,18 +88,18 @@ void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token llama_sampling_sample( - struct llama_sampling * smpl, - struct llama_context * ctx, - int idx); +llama_token gpt_sampler_sample( + struct gpt_sampler * gsmpl, - struct llama_context * ctx, + struct llama_context * ctx, + int idx); // helpers // get a string representation of the last accepted tokens -std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); +std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); -char llama_sampling_type_to_chr(enum llama_constraint_type sampler_type); -std::string llama_sampling_type_to_str(enum llama_constraint_type sampler_type); +char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr); +std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr); -std::vector<llama_constraint_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); -std::vector<llama_constraint_type> llama_sampling_types_from_chars(const std::string & chars); +std::vector<gpt_constraint_type> gpt_constraint_types_from_names(const std::vector<std::string> & names, bool allow_alt_names); +std::vector<gpt_constraint_type> gpt_constraint_types_from_chars(const std::string & chars);
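The common/sampling.h changes above are the heart of the common-layer rework: gpt_sampler bundles a grammar constraint (grmr) with a llama_sampler (smpl), and the old llama_sampling overloads become plain gpt_sampler_* functions. A minimal sketch of the intended call sequence follows, using only declarations from this header and assuming a loaded llama_model * named `model` (the example itself is not part of the patch):

```cpp
// sketch, not part of the patch: construct a gpt_sampler from common params
gpt_sampler_params sparams;

sparams.top_k = 40;
sparams.temp  = 0.7f;

// accepts canonical and alternative spellings ("top-k", "temp", ...)
sparams.constraints = gpt_constraint_types_from_names({ "top_k", "min_p", "temperature" }, /* allow_alt_names = */ true);

gpt_sampler * gsmpl = gpt_sampler_init(model, sparams);

// prints something like "CFG -> Penalties -> top_k -> min_p -> temperature"
fprintf(stderr, "%s\n", sparams.print_constraints().c_str());

gpt_sampler_free(gsmpl);
```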
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 296c1c687..88202b800 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,7 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static llama_sampling ** g_smpl; +static gpt_sampler ** g_smpl; static gpt_params * g_params; static std::vector<llama_token> * g_input_tokens; static std::ostringstream * g_output_ss; @@ -106,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, *g_smpl); + llama_print_timings(*g_ctx, (*g_smpl)->smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -193,7 +193,7 @@ int main(int argc, char ** argv) { llama_model * model = nullptr; llama_context * ctx = nullptr; - llama_sampling * smpl = nullptr; + gpt_sampler * smpl = nullptr; std::vector<llama_chat_msg> chat_msgs; @@ -458,7 +458,7 @@ int main(int argc, char ** argv) { } } LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str()); - LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str()); + LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state @@ -525,7 +525,7 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - smpl = llama_sampling_init(model, sparams); + smpl = gpt_sampler_init(model, sparams); if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -681,9 +681,9 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(smpl, ctx, -1); + const llama_token id = gpt_sampler_sample(smpl, ctx, -1); - llama_sampling_accept(smpl, id, /* apply_grammar= */ true); + gpt_sampler_accept(smpl, id, /* apply_grammar= */ true); // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); @@ -704,7 +704,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); + gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -747,7 +747,7 @@ int main(int argc, char ** argv) { // check for reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { const int n_prev = 32; - const std::string last_output = llama_sampling_prev_str(smpl, ctx, n_prev); + const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output.
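Taken together, the main.cpp hunks above and below are a mechanical rename; pulled into one place, the generation loop under the new names looks roughly like this (a sketch under assumptions: `ctx`, `model`, `sparams`, `n_past` and `n_remain` are set up and the prompt has been decoded as before; llama_decode and llama_batch_get_one are the unchanged llama.h calls):

```cpp
// sketch, not part of the patch: the renamed sampling loop
gpt_sampler * gsmpl = gpt_sampler_init(model, sparams);

while (n_remain > 0) {
    llama_token id = gpt_sampler_sample(gsmpl, ctx, -1);

    gpt_sampler_accept(gsmpl, id, /* apply_grammar= */ true);

    if (llama_token_is_eog(model, id)) {
        break;
    }

    printf("%s", llama_token_to_piece(ctx, id, false).c_str());

    // feed the sampled token back in for the next step
    if (llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0))) {
        break;
    }

    n_past   += 1;
    n_remain -= 1;
}

llama_print_timings(ctx, gsmpl->smpl); // the timings API now takes the inner llama_sampler
gpt_sampler_free(gsmpl);               // freed before llama_free(ctx), as in the hunks below
```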
@@ -769,7 +769,7 @@ int main(int argc, char ** argv) { } // check for reverse prompt using special tokens - llama_token last_token = llama_sampling_last(smpl); + llama_token last_token = gpt_sampler_last(smpl); for (std::vector<llama_token> ids : antiprompt_ids) { if (ids.size() == 1 && last_token == ids[0]) { if (params.interactive) { @@ -786,7 +786,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, llama_sampling_last(smpl))) { + if (llama_token_is_eog(model, gpt_sampler_last(smpl))) { LOG("found an EOG token\n"); if (params.interactive) { @@ -807,7 +807,7 @@ int main(int argc, char ** argv) { // if current token is not EOG, we add it to current assistant message if (params.conversation) { - auto id = llama_sampling_last(smpl); + const auto id = gpt_sampler_last(smpl); assistant_ss << llama_token_to_piece(ctx, id, false); } @@ -903,7 +903,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(smpl); + gpt_sampler_reset(smpl); } is_interacting = false; } @@ -928,13 +928,14 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx, smpl); + llama_print_timings(ctx, smpl->smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + gpt_sampler_free(smpl); + llama_free(ctx); llama_free_model(model); - llama_sampling_free(smpl); llama_backend_free(); ggml_threadpool_free(threadpool); diff --git a/include/llama.h b/include/llama.h index 4dd5348a8..920952d68 100644 --- a/include/llama.h +++ b/include/llama.h @@ -63,7 +63,6 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_sampling; // TODO: remove before merge typedef int32_t llama_pos; typedef int32_t llama_token; @@ -210,17 +209,6 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; - // TODO: move to common, rename to gpt_constraint_type - enum llama_constraint_type { - LLAMA_CONSTRAINT_TYPE_NONE = 0, - LLAMA_CONSTRAINT_TYPE_TOP_K = 1, - LLAMA_CONSTRAINT_TYPE_TOP_P = 2, - LLAMA_CONSTRAINT_TYPE_MIN_P = 3, - LLAMA_CONSTRAINT_TYPE_TFS_Z = 4, - LLAMA_CONSTRAINT_TYPE_TYPICAL_P = 5, - LLAMA_CONSTRAINT_TYPE_TEMPERATURE = 6, - }; - typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -384,38 +372,6 @@ extern "C" { float bias; } llama_logit_bias; - // TODO: remove before merge - // parameters for sampling the logits - typedef struct llama_sampling_params { - uint32_t seed; // the seed used to initialize llama_sampling_context - int32_t n_prev; // number of previous tokens to remember - int32_t n_probs; // if greater than 0, output the probabilities of top n_probs tokens.
- int32_t min_keep; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k; // <= 0 to use vocab size - float top_p; // 1.0 = disabled - float min_p; // 0.0 = disabled - float tfs_z; // 1.0 = disabled - float typ_p; // typical_p, 1.0 = disabled - float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range; // 0.0 = disabled - float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat; // 1.0 = disabled - float penalty_freq; // 0.0 = disabled - float penalty_present; // 0.0 = disabled - int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau; // target entropy - float mirostat_eta; // learning rate - - // samplers - int32_t n_samplers; - enum llama_constraint_type samplers[LLAMA_MAX_SAMPLERS]; - - // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. - bool penalize_nl; // consider newlines as a repeatable token - bool ignore_eos; // ignore the end-of-sequence token - } llama_sampling_params; - typedef struct llama_sampler_params { uint32_t seed; // the seed used to initialize the rng of the sampler @@ -432,14 +388,10 @@ extern "C" { double t_end_ms; double t_load_ms; double t_sampling_ms; - double t_grammar_ms; - double t_accept_ms; double t_p_eval_ms; double t_eval_ms; int32_t n_sampling; - int32_t n_grammar; - int32_t n_accept; int32_t n_p_eval; int32_t n_eval; }; @@ -458,7 +410,6 @@ extern "C" { LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_sampler_params llama_sampler_default_params(void); - LLAMA_API struct llama_sampling_params llama_sampling_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); // Initialize the llama + ggml backend @@ -1052,126 +1003,126 @@ extern "C" { // // TODO: llama_model should become llama_vocab - LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); + //LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); - LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); + //LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); - // Copies the internal state of the sampler (rng, prev, params, grammar, etc.) - LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); + //// Copies the internal state of the sampler (rng, prev, params, grammar, etc.) + //LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - // - clear prev token - // - reset grammar state - LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); + //// - clear prev token + //// - reset grammar state + //LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); - // Sampling parameter mutation - // TODO: not sure if we want to keep these. 
Maybe it's better to keep llama_sampling immutable - LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); - LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); + //// Sampling parameter mutation + //// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable + //LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + //LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - // Set the logits from which to sample. - // This call initializes the internal token candidates array. - // The internal candidates are implicitly used by the sampling API below when no candidates are provided. - LLAMA_API void llama_sampling_set_logits( - struct llama_sampling * smpl, - const float * logits); + //// Set the logits from which to sample. + //// This call initializes the internal token candidates array. + //// The internal candidates are implicitly used by the sampling API below when no candidates are provided. + //LLAMA_API void llama_sampling_set_logits( + // struct llama_sampling * smpl, + // const float * logits); - /// @details Returns the current candidate tokens. - LLAMA_API llama_token_data_array * llama_sampling_get_candidates( - struct llama_sampling * smpl); + ///// @details Returns the current candidate tokens. + //LLAMA_API llama_token_data_array * llama_sampling_get_candidates( + // struct llama_sampling * smpl); - // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. - // Each function can accept an array of token candidates. If the candidates are not provided, the internal - // candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). + //// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. + //// Each function can accept an array of token candidates. If the candidates are not provided, the internal + //// candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). - /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sampling_softmax( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
+ //LLAMA_API void llama_sampling_softmax( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sampling_top_k( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + //LLAMA_API void llama_sampling_top_k( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sampling_top_p( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + //LLAMA_API void llama_sampling_top_p( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API void llama_sampling_min_p( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + //LLAMA_API void llama_sampling_min_p( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sampling_tail_free( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + //LLAMA_API void llama_sampling_tail_free( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sampling_typical( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + //LLAMA_API void llama_sampling_typical( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Apply temperature and entropy - LLAMA_API void llama_sampling_temp( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Apply temperature and entropy + //LLAMA_API void llama_sampling_temp( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Apply constraints from grammar - LLAMA_API void llama_sampling_grammar( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Apply constraints from grammar + //LLAMA_API void llama_sampling_grammar( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. 
- LLAMA_API void llama_sampling_penalties( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + ///// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + //LLAMA_API void llama_sampling_penalties( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - LLAMA_API llama_token llama_sampling_sample_mirostat( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + //LLAMA_API llama_token llama_sampling_sample_mirostat( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Selects the token with the highest probability. - /// Does not compute the token probabilities. Use llama_sampling_softmax() instead. - LLAMA_API llama_token llama_sampling_sample_greedy( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Selects the token with the highest probability. + ///// Does not compute the token probabilities. Use llama_sampling_softmax() instead. + //LLAMA_API llama_token llama_sampling_sample_greedy( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probability distribution. - LLAMA_API llama_token llama_sampling_sample_dist( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Randomly selects a token from the candidates based on their probability distribution. + //LLAMA_API llama_token llama_sampling_sample_dist( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). - LLAMA_API llama_token llama_sampling_sample( - struct llama_sampling * smpl, - llama_token_data_array * candidates); + ///// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). + //LLAMA_API llama_token llama_sampling_sample( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); - /// @details Accepts the sampled token into the sampling context. - /// - adds it to "prev" tokens - /// - updates the grammar state (if apply_grammar is true) - LLAMA_API void llama_sampling_accept( - struct llama_sampling * smpl, - llama_token token, - bool apply_grammar); + ///// @details Accepts the sampled token into the sampling context. + ///// - adds it to "prev" tokens + ///// - updates the grammar state (if apply_grammar is true) + //LLAMA_API void llama_sampling_accept( + // struct llama_sampling * smpl, + // llama_token token, + // bool apply_grammar); - /// @details Get the number of accepted tokens so far (max of n_prev) - LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); + ///// @details Get the number of accepted tokens so far (max of n_prev) + //LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); - /// @details Get the ith accepted token - /// @param ith [0, n_prev), ith == 0 is the last accepted token. 
- /// returns LLAMA_TOKEN_NULL if ith is out of bounds - LLAMA_API llama_token llama_sampling_prev( - const struct llama_sampling * smpl, - int32_t ith); + ///// @details Get the ith accepted token + ///// @param ith [0, n_prev), ith == 0 is the last accepted token. + ///// returns LLAMA_TOKEN_NULL if ith is out of bounds + //LLAMA_API llama_token llama_sampling_prev( + // const struct llama_sampling * smpl, + // int32_t ith); - /// @details Get the last accepted token - /// Same as llama_sampling_prev(smpl, 0) - /// returns LLAMA_TOKEN_NULL if there are no accepted tokens - LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); + ///// @details Get the last accepted token + ///// Same as llama_sampling_prev(smpl, 0) + ///// returns LLAMA_TOKEN_NULL if there are no accepted tokens + //LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); // // Sampling v2 API @@ -1204,11 +1155,11 @@ extern "C" { struct llama_constraint_i { // TODO: add name API - void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required - void (*reset) (struct llama_constraint * cnstr); // can be NULL - void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); // can be NULL if ctx is NULL - void (*free) (struct llama_constraint * cnstr); // can be NULL + void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL + void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * candidates); // required + void (*reset) ( struct llama_constraint * cnstr); // can be NULL + struct llama_constraint * (*copy) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL + void (*free) ( struct llama_constraint * cnstr); // can be NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1228,21 +1179,27 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); - LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); + + LLAMA_API struct llama_constraint * llama_constraint_init_grammar( + const struct llama_model * model, + const char * grammar_str, + const char * grammar_root); LLAMA_API struct llama_constraint * llama_constraint_init_penalties( - struct llama_model * model, - int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat, // 1.0 = disabled - float penalty_freq, // 0.0 = disabled - float penalty_present, // 0.0 = disabled - bool penalize_nl, // consider newlines as a repeatable token - bool ignore_eos); // ignore the end-of-sequence token + const struct llama_model * model, + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present, // 0.0 = disabled + bool penalize_nl, // consider newlines as a repeatable token + bool ignore_eos); // ignore the end-of-sequence token LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( - struct 
llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias); + const struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); + + LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); @@ -1273,7 +1230,7 @@ extern "C" { LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * candidates); LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * candidates); - LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates, bool probs); LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates); /// @details Get the number of accepted tokens so far (max of n_prev) @@ -1310,8 +1267,8 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl); - LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl); + LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl); + LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl); // Print system information LLAMA_API const char * llama_print_system_info(void);
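For code that does not need the common-layer wrapper, the surviving v2 surface in llama.h is deliberately small: a llama_sampler holding shared state (rng, prev tokens, timings) plus a chain of llama_constraint objects. A hedged sketch of direct use, restricted to calls declared in the hunks above and assuming a loaded `model` and a decoded `ctx`:

```cpp
// sketch, not part of the patch: the low-level constraint/sampler API
llama_sampler_params lparams = llama_sampler_default_params();

lparams.seed = 1234;

llama_sampler * smpl = llama_sampler_init(model, lparams);

// constraints run in the order they are added; per the comment on
// llama_constraint_cp above, the sampler appears to take ownership,
// so they are not freed individually afterwards
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
llama_sampler_add_constraint(smpl, llama_constraint_init_min_p(0.05f, 1));
llama_sampler_add_constraint(smpl, llama_constraint_init_temp(0.80f));

llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, -1));

llama_token_data_array * cur_p = llama_sampler_get_candidates(smpl);

llama_sampler_apply(smpl, cur_p);                      // run the constraint chain

const llama_token id = llama_sampler_sample_dist(smpl, cur_p);

llama_sampler_accept(smpl, id);

llama_sampler_free(smpl);
```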
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f21b5fd55..7a1f8a805 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -24,107 +24,7 @@ static void llama_log_softmax(float * array, size_t size) { } } -llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) { -} - -llama_sampling::~llama_sampling() { - if (grammar) { - llama_grammar_free_impl(grammar); - } -} - -struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) { - auto * result = new llama_sampling(vocab); - - result->params = params; - - result->prev = ring_buffer<llama_token>(params.n_prev); - - for (int i = 0; i < params.n_samplers; ++i) { - result->samplers.push_back(params.samplers[i]); - } - - llama_sampling_set_rng_seed_impl(*result, params.seed); - - return result; -} - -void llama_sampling_free_impl(struct llama_sampling * sampling) { - if (sampling == nullptr) { - return; - } - - delete sampling; -} - -struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { - auto * result = new llama_sampling(smpl.vocab); - - result->params = smpl.params; - - result->grammar_str = smpl.grammar_str; - result->grammar_root = smpl.grammar_root; - - result->logit_bias = smpl.logit_bias; - - if (smpl.grammar) { - result->grammar = llama_grammar_cp_impl(*smpl.grammar); - } - - result->rng = smpl.rng; - result->prev = smpl.prev; - - return result; -} - -void llama_sampling_reset_impl(struct llama_sampling & smpl) { - if (smpl.grammar) { - llama_grammar_free_impl(smpl.grammar); - smpl.grammar = nullptr; - } - - if (!smpl.grammar_str.empty()) { - smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); - } - - smpl.prev.clear(); -} - -void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - - smpl.rng.seed(seed); -} - -void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { - if (smpl.grammar) { - llama_grammar_free_impl(smpl.grammar); - smpl.grammar = nullptr; - } - - if (grammar_str != nullptr && grammar_str[0] != '\0') { - smpl.grammar_str = grammar_str; - smpl.grammar_root = grammar_root; - - smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root); - } else { - smpl.grammar_str.clear(); - smpl.grammar_root.clear(); - } -} - -void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - smpl.logit_bias.clear(); - smpl.logit_bias.reserve(n_logit_bias); - - for (int32_t i = 0; i < n_logit_bias; ++i) { - smpl.logit_bias.push_back(logit_bias[i]); - } -} - -void llama_sampling_softmax_impl(llama_token_data_array * candidates) { +void llama_constraint_softmax_impl(llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); // Sort the logits in descending order @@ -149,7 +49,7 @@ void llama_sampling_softmax_impl(llama_token_data_array * candidates) { } } -void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_constraint_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; @@ -226,12 +126,12 @@ void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, s candidates->size = k; } -void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_constraint_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -252,7 +152,7 @@ void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, siz candidates->size = last_idx; } -void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_constraint_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } @@ -307,12 +207,12 @@ void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, siz } } -void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_constraint_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Compute the first and second derivatives std::vector<float> first_derivatives(candidates->size - 1); @@ -361,7 +261,7 @@ void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, candidates->size = last_idx; } -void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_constraint_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { return; @@ -369,7 +269,7 @@ void
llama_sampling_typical_impl(llama_token_data_array * candidates, float p, s } // Compute the softmax of logits and calculate entropy - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -419,7 +319,7 @@ void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, s candidates->sorted = false; } -void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { +void llama_constraint_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -428,7 +328,7 @@ void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_ // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -482,17 +382,17 @@ void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_ #endif } -void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) { +void llama_constraint_temp_impl(llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } } -void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { +void llama_constraint_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { llama_grammar_apply_impl(grammar, candidates); } -void llama_sampling_penalties_impl( +void llama_constraint_penalties_impl( llama_token_data_array * candidates, const llama_token_cnt & token_count, float penalty_repeat, @@ -521,8 +421,8 @@ void llama_sampling_penalties_impl( candidates->sorted = false; } -llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_sampling_softmax_impl(candidates); +llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_constraint_softmax_impl(candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -541,8 +441,8 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sampling_top_k_impl(candidates, int(k), 1); - llama_token X = llama_sampling_sample_dist_impl(candidates, rng); + llama_constraint_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -557,8 +457,8 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * return X; } -llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { - 
llama_sampling_softmax_impl(candidates); +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_constraint_softmax_impl(candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -570,10 +470,10 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array } // Normalize the probabilities of the remaining words - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Sample the next word X from the remaining words - llama_token X = llama_sampling_sample_dist_impl(candidates, rng); + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -589,8 +489,16 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array return X; } -llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) { - // Find max element +llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { + if (probs) { + // if probs are needed, we apply softmax to get the probabilities + llama_constraint_softmax_impl(candidates); + + // the candidates are sorted, so we can just return the first one + return candidates->data[0].id; + } + + // return the token with the highest logit auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); @@ -600,8 +508,8 @@ llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidate return result; } -llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { - llama_sampling_softmax_impl(candidates); +llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_constraint_softmax_impl(candidates); std::vector<float> probs; probs.reserve(candidates->size); @@ -618,26 +526,6 @@ llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * cand return result; } -void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) { - smpl.prev.push_back(token); - - if (apply_grammar && smpl.grammar) { - llama_grammar_accept_impl(*smpl.grammar, token); - } -} - -llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) { - if (ith < 0 || ith >= (int) smpl.prev.size()) { - return LLAMA_TOKEN_NULL; - } - - return smpl.prev.rat(ith); -} - -int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) { - return smpl.prev.size(); -} - // // sampling v2 // @@ -655,14 +543,12 @@ static struct llama_constraint_i llama_constraint_top_k_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_sampling_top_k_impl(candidates, ctx->k, ctx->min_keep); + llama_constraint_top_k_impl(candidates, ctx->k, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct
llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_top_k; - const auto * ctx_src = (const llama_constraint_context_top_k *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_top_k *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; + return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -695,14 +581,12 @@ static struct llama_constraint_i llama_constraint_top_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; - llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_top_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_top_p; - const auto * ctx_src = (const llama_constraint_context_top_p *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_top_p *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_top_p *) cnstr->ctx; + return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -735,14 +619,12 @@ static struct llama_constraint_i llama_constraint_min_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; - llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_min_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_min_p; - const auto * ctx_src = (const llama_constraint_context_min_p *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_min_p *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_min_p *) cnstr->ctx; + return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -775,14 +657,12 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; - llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep); + llama_constraint_tail_free_impl(candidates, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_tail_free; - const auto * ctx_src = (const llama_constraint_context_tail_free *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_tail_free *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_tail_free *) cnstr->ctx; + return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep); }, /* .free = */ [](struct 
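// The rewritten .copy callbacks above all share one pattern: instead of filling a
// caller-allocated destination, copy() re-runs the constraint's constructor with
// the source's parameters, so construction logic lives in exactly one place. A
// hypothetical user-defined constraint written against the same interface (sketch;
// the logit-scaling transform and the scale_* names are invented for illustration):
struct scale_ctx { float scale; };

static struct llama_constraint * scale_constraint_init(float scale);

static struct llama_constraint_i scale_constraint_i = {
    /* .accept = */ nullptr,
    /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
        const auto * ctx = (const scale_ctx *) cnstr->ctx;
        for (size_t i = 0; i < candidates->size; ++i) {
            candidates->data[i].logit *= ctx->scale;
        }
    },
    /* .reset  = */ nullptr,
    /* .copy   = */ [](const struct llama_constraint * cnstr) {
        const auto * ctx = (const scale_ctx *) cnstr->ctx;
        return scale_constraint_init(ctx->scale); // reconstruct from parameters
    },
    /* .free   = */ [](struct llama_constraint * cnstr) {
        delete (scale_ctx *) cnstr->ctx;
    },
};

static struct llama_constraint * scale_constraint_init(float scale) {
    return new llama_constraint {
        /* .iface = */ &scale_constraint_i,
        /* .ctx   = */ new scale_ctx { scale },
    };
}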
llama_constraint * cnstr) { if (cnstr->ctx) { @@ -815,14 +695,12 @@ static struct llama_constraint_i llama_constraint_typical_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; - llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_typical_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_typical; - const auto * ctx_src = (const llama_constraint_context_typical *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_typical *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_typical *) cnstr->ctx; + return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -854,14 +732,12 @@ static struct llama_constraint_i llama_constraint_temp_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; - llama_sampling_temp_impl(candidates, ctx->temp); + llama_constraint_temp_impl(candidates, ctx->temp); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_temp; - const auto * ctx_src = (const llama_constraint_context_temp *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_temp *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_temp *) cnstr->ctx; + return llama_constraint_init_temp_impl(ctx->temp); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -898,17 +774,15 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; - llama_sampling_entropy_impl(candidates, temp_min, temp_max, ctx->exponent); + llama_constraint_entropy_impl(candidates, temp_min, temp_max, ctx->exponent); } else { - llama_sampling_temp_impl(candidates, ctx->temp); + llama_constraint_temp_impl(candidates, ctx->temp); } }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_temp_ext; - const auto * ctx_src = (const llama_constraint_context_temp_ext *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_temp_ext *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_temp_ext *) cnstr->ctx; + return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -950,7 +824,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = { /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { - llama_sampling_grammar_impl(candidates, *ctx->grammar); + llama_constraint_grammar_impl(candidates, *ctx->grammar); } }, /* .reset = */ 
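// temp_ext above degrades gracefully: with delta == 0 it is plain temperature
// scaling, and with delta > 0 the effective temperature is picked inside
// [max(0, temp - delta), temp + delta] from the entropy of the candidates
// (higher entropy -> higher temperature). A toy illustration of that mapping,
// assuming p[] already holds normalized probabilities (the real
// llama_constraint_entropy_impl differs in detail):
#include <algorithm>
#include <cmath>

static float dynamic_temp_sketch(const float * p, size_t n, float temp, float delta, float exponent) {
    const float temp_min = std::max(0.0f, temp - delta);
    const float temp_max = temp + delta;

    float entropy = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        if (p[i] > 0.0f) {
            entropy -= p[i] * std::log(p[i]);
        }
    }

    const float max_entropy = std::log((float) n);             // entropy of the uniform distribution
    const float t = std::pow(entropy / max_entropy, exponent); // normalized and shaped by the exponent
    return temp_min + t * (temp_max - temp_min);               // interpolate into the allowed range
}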
[](struct llama_constraint * cnstr) { @@ -964,18 +838,19 @@ static struct llama_constraint_i llama_constraint_grammar_i = { ctx->grammar = llama_grammar_init_impl(nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); } }, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_grammar; - const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_grammar *) cnstr->ctx; - - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; + auto * result = llama_constraint_init_grammar_impl(*ctx_src->grammar->vocab, nullptr, nullptr); + auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx; if (ctx_src->grammar) { + ctx_dst->grammar_str = ctx_src->grammar_str; + ctx_dst->grammar_root = ctx_src->grammar_root; + ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar); - } else { - ctx_dst->grammar = nullptr; } + + return result; }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -1059,7 +934,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = { token_count[ctx->prev.rat(i)]++; } - llama_sampling_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + llama_constraint_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); if (!ctx->penalize_nl) { // restore the logit of the newline token if it was penalized @@ -1070,12 +945,21 @@ static struct llama_constraint_i llama_constraint_penalties_i = { auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; ctx->prev.clear(); }, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_penalties; - const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_penalties *) cnstr->ctx; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr->ctx; + auto * result = llama_constraint_init_penalties_impl( + *ctx_src->vocab, + ctx_src->penalty_last_n, + ctx_src->penalty_repeat, + ctx_src->penalty_freq, + ctx_src->penalty_present, + ctx_src->penalize_nl, + ctx_src->ignore_eos); - *ctx_dst = *ctx_src; + auto * ctx_dst = (llama_constraint_context_penalties *) result->ctx; + ctx_dst->prev = ctx_src->prev; + + return result; }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -1126,12 +1010,9 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = { } }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_logit_bias; - const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_logit_bias *) cnstr->ctx; - - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr->ctx; + return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -1158,6 +1039,10 @@ struct llama_constraint * 
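// The penalties constraint above counts the last penalty_last_n accepted tokens
// into an unordered_map and hands the counts to llama_constraint_penalties_impl.
// Roughly, that impl applies the classic repeat/frequency/presence adjustments
// per candidate (simplified sketch, not the exact implementation):
#include <unordered_map>

static void penalties_sketch(llama_token_data_array * candidates,
                             const std::unordered_map<llama_token, int> & token_count,
                             float penalty_repeat, float penalty_freq, float penalty_present) {
    for (size_t i = 0; i < candidates->size; ++i) {
        const auto it = token_count.find(candidates->data[i].id);
        if (it == token_count.end()) {
            continue;
        }
        const int count = it->second;

        float & logit = candidates->data[i].logit;

        // repeat penalty: push the logit toward "less likely" regardless of sign
        logit = logit <= 0.0f ? logit * penalty_repeat : logit / penalty_repeat;

        // frequency/presence penalties (OpenAI-style)
        logit -= (float) count * penalty_freq + (count > 0 ? 1.0f : 0.0f) * penalty_present;
    }
}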
llama_constraint_init_logit_bias_impl( //////////////////////////////////////// +struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr) { + return cnstr.iface->copy ? cnstr.iface->copy(&cnstr) : nullptr; +} + void llama_constraint_free_impl(struct llama_constraint * cnstr) { if (cnstr && cnstr->iface->free) { cnstr->iface->free(cnstr); @@ -1214,12 +1099,14 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) // copy the constraints objects result->constraints.clear(); for (const auto & cnstr : smpl.constraints) { - result->constraints.push_back(new llama_constraint); - result->constraints.back()->iface = cnstr->iface; - - if (cnstr->ctx) { + if (cnstr->ctx == nullptr) { + result->constraints.push_back(new llama_constraint { + /* .iface = */ cnstr->iface, + /* .ctx = */ nullptr, + }); + } else { GGML_ASSERT(cnstr->iface->copy); - result->constraints.back()->iface->copy(result->constraints.back(), cnstr); + result->constraints.push_back(cnstr->iface->copy(cnstr)); } } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 7de37c89e..dd9236392 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,74 +10,17 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map<llama_token, int>; -// TODO: remove before merge -struct llama_sampling { - llama_sampling(const struct llama_vocab & vocab); - ~llama_sampling(); +void llama_constraint_softmax_impl (struct llama_token_data_array * candidates); +void llama_constraint_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_constraint_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_constraint_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_constraint_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); +void llama_constraint_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_constraint_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_constraint_temp_impl (struct llama_token_data_array * candidates, float temp); +void llama_constraint_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); - llama_sampling_params params; - - std::string grammar_str; - std::string grammar_root; - - std::vector<llama_logit_bias> logit_bias; // logit biases to apply - - // state - - std::mt19937 rng; - - const struct llama_vocab & vocab; - - std::vector<enum llama_constraint_type> samplers; - - ring_buffer<llama_token> prev; - - struct llama_grammar * grammar = nullptr; - - // mirostat sampler state - float mirostat_mu; - - mutable int64_t t_sample_us = 0; - mutable int64_t t_grammar_us = 0; - mutable int64_t t_accept_us = 0; - - mutable int32_t n_sample = 0; - mutable int32_t n_grammar = 0; - mutable int32_t n_accept = 0; - - std::vector<llama_token_data> cur; - - llama_token_data_array cur_p; -}; - -// -// internal API -// - -struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params); - -void llama_sampling_free_impl(struct llama_sampling * sampling); - -struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl); - -void llama_sampling_reset_impl(struct llama_sampling & smpl); - -// TODO: move the API below as member functions of llama_sampling -void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed); -void
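// llama_constraint_cp_impl above reduces cloning to a single dispatch and returns
// nullptr for constraints without a .copy. llama_sampler_cp_impl pairs it with
// the ctx == nullptr case: stateless constraints only need to share the interface
// table. The same policy as a standalone helper (hypothetical, for illustration):
static struct llama_constraint * clone_constraint(const struct llama_constraint * cnstr) {
    if (cnstr->ctx == nullptr) {
        // stateless: nothing to deep-copy, reuse the vtable
        return new llama_constraint {
            /* .iface = */ cnstr->iface,
            /* .ctx   = */ nullptr,
        };
    }

    // stateful: defer to the constraint's own copy, which may be absent
    return llama_constraint_cp_impl(*cnstr);
}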
llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); -void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - -void llama_sampling_softmax_impl (struct llama_token_data_array * candidates); -void llama_sampling_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sampling_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sampling_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sampling_temp_impl (struct llama_token_data_array * candidates, float temp); -void llama_sampling_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); - -void llama_sampling_penalties_impl( +void llama_constraint_penalties_impl( llama_token_data_array * candidates, const llama_token_cnt & token_count, float penalty_repeat, @@ -90,22 +33,18 @@ void llama_sampling_penalties_impl( /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); +llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. 
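// The @param docs above describe mirostat's feedback loop in words; its core is
// a one-line control update. After sampling token X with normalized probability
// p(X), the observed surprise -log2 p(X) is compared against the target tau and
// mu is nudged by the learning rate eta. Illustrative sketch:
#include <cmath>

static void mirostat_update_mu(float p_sampled, float tau, float eta, float & mu) {
    const float observed_surprise = -std::log2(p_sampled);
    const float e = observed_surprise - tau; // error vs. the target surprise
    mu -= eta * e;                           // learning-rate step on mu
}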
This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); -llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates); -llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs); +llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); -void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar); - -llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith); -int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); // @@ -141,6 +80,8 @@ struct llama_constraint * llama_constraint_init_penalties_impl( int32_t n_logit_bias, const llama_logit_bias * logit_bias); +struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr); + void llama_constraint_free_impl(struct llama_constraint * cnstr); void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index 4060fa1de..a40fc4c30 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17946,36 +17946,6 @@ struct llama_sampler_params llama_sampler_default_params() { return result; } -struct llama_sampling_params llama_sampling_default_params() { - struct llama_sampling_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_prev =*/ 64, - /*.n_probs =*/ 0, - /*.min_keep =*/ 0, - /*.top_k =*/ 40, - /*.top_p =*/ 0.95f, - /*.min_p =*/ 0.05f, - /*.tfs_z =*/ 1.00f, - /*.typ_p =*/ 1.00f, - /*.temp =*/ 0.80f, - /*.dynatemp_range =*/ 0.00f, - /*.dynatemp_exponent =*/ 1.00f, - /*.penalty_last_n =*/ 64, - /*.penalty_repeat =*/ 1.00f, - /*.penalty_freq =*/ 0.00f, - /*.penalty_present =*/ 0.00f, - /*.mirostat =*/ 0, - /*.mirostat_tau =*/ 5.00f, - /*.mirostat_eta =*/ 0.10f, - /*.n_samplers =*/ 3, - /*.samplers =*/ { LLAMA_CONSTRAINT_TYPE_TEMPERATURE, LLAMA_CONSTRAINT_TYPE_TOP_K, LLAMA_CONSTRAINT_TYPE_TOP_P, }, - /*.penalize_nl =*/ false, - /*.ignore_eos =*/ false, - }; - - return result; -} - struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { /*.nthread =*/ 0, @@ -20638,341 +20608,341 @@ int32_t llama_chat_apply_template( // sampling // -struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { - return llama_sampling_init_impl(model->vocab, params); -} +//struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { +// return llama_sampling_init_impl(model->vocab, params); +//} -void llama_sampling_free(struct llama_sampling * smpl) { - if (smpl == nullptr) { - return; - } +//void llama_sampling_free(struct llama_sampling * smpl) { +// if (smpl == nullptr) { +// return; +// } - llama_sampling_free_impl(smpl); -} +// llama_sampling_free_impl(smpl); +//} -struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { - return llama_sampling_cp_impl(*smpl); 
-} +//struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { +// return llama_sampling_cp_impl(*smpl); +//} -void llama_sampling_reset(struct llama_sampling * smpl) { - llama_sampling_reset_impl(*smpl); -} +//void llama_sampling_reset(struct llama_sampling * smpl) { +// llama_sampling_reset_impl(*smpl); +//} -void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { - llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); -} +//void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { +// llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); +//} -void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); -} +//void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { +// llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); +//} -void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { - const int n_vocab = smpl->vocab.n_vocab; +//void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { +// const int n_vocab = smpl->vocab.n_vocab; - smpl->cur.resize(n_vocab); +// smpl->cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } +// for (llama_token token_id = 0; token_id < n_vocab; token_id++) { +// smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; +// } - for (const auto & lb : smpl->logit_bias) { - smpl->cur[lb.token].logit += lb.bias; - } +// for (const auto & lb : smpl->logit_bias) { +// smpl->cur[lb.token].logit += lb.bias; +// } - if (smpl->params.ignore_eos) { - smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; - } +// if (smpl->params.ignore_eos) { +// smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; +// } - smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; +// smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; - // apply penalties - { - const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; +// // apply penalties +// { +// const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; - llama_sampling_penalties(smpl, &smpl->cur_p); +// llama_sampling_penalties(smpl, &smpl->cur_p); - if (!smpl->params.penalize_nl) { - for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { - if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { - smpl->cur_p.data[idx].logit = nl_logit; - break; - } - } - } - } -} +// if (!smpl->params.penalize_nl) { +// for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { +// if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { +// smpl->cur_p.data[idx].logit = nl_logit; +// break; +// } +// } +// } +// } +//} -llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { - return &smpl->cur_p; -} +//llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { +// return &smpl->cur_p; +//} -void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas 
tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_softmax_impl(candidates); -} +// llama_sampling_softmax_impl(candidates); +//} -void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); -} +// llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); +//} -void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); -} +// llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); +//} -void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); -} +// llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); +//} -void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); -} +// llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); +//} -void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); -} +// llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); +//} -void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = 
&smpl->cur_p; +// } - if (smpl->params.dynatemp_range > 0) { - const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); - const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); +// if (smpl->params.dynatemp_range > 0) { +// const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); +// const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); - llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); - } else { - llama_sampling_temp_impl(candidates, smpl->params.temp); - } -} +// llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); +// } else { +// llama_sampling_temp_impl(candidates, smpl->params.temp); +// } +//} -void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_grammar_us); +//void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_grammar_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - if (smpl->grammar) { - llama_sampling_grammar_impl(candidates, *smpl->grammar); +// if (smpl->grammar) { +// llama_sampling_grammar_impl(candidates, *smpl->grammar); - smpl->n_grammar++; - } -} +// smpl->n_grammar++; +// } +//} -void llama_sampling_penalties( - struct llama_sampling * smpl, - llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_penalties( +// struct llama_sampling * smpl, +// llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); +// const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); - const float penalty_repeat = smpl->params.penalty_repeat; - const float penalty_freq = smpl->params.penalty_freq; - const float penalty_present = smpl->params.penalty_present; +// const float penalty_repeat = smpl->params.penalty_repeat; +// const float penalty_freq = smpl->params.penalty_freq; +// const float penalty_present = smpl->params.penalty_present; - if ((penalty_last_n == 0) || - (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } +// if ((penalty_last_n == 0) || +// (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { +// return; +// } - // Create a frequency map to count occurrences of each token in last_tokens - // TODO: move to sampling state and avoid reallocation - llama_token_cnt token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[smpl->prev.rat(i)]++; - } +// // Create a frequency map to count occurrences of each token in last_tokens +// // TODO: move to sampling state and avoid reallocation +// llama_token_cnt token_count; +// for (size_t i = 0; i < penalty_last_n; ++i) { +// token_count[smpl->prev.rat(i)]++; +// } - llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); -} +// llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); +//} -llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * 
candidates) { - time_meas tm(smpl->t_sample_us); +//llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - const auto type = smpl->params.mirostat; +// const auto type = smpl->params.mirostat; - llama_token res; +// llama_token res; - if (type == 1) { - res = llama_sampling_sample_mirostat_impl(candidates, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - 100, - smpl->vocab.n_vocab, - smpl->mirostat_mu); - } else if (type == 2) { - res = llama_sampling_sample_mirostat_v2_impl(candidates, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - smpl->mirostat_mu); - } else { - GGML_ABORT("invalid mirostat type: %d", type); - } - - smpl->n_sample++; - - return res; -} - -llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); - - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } - - auto res = llama_sampling_sample_greedy_impl(candidates); - - smpl->n_sample++; - - return res; -} - -llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); - - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } - - auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); - - smpl->n_sample++; - - return res; -} - -llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); - - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } - - const auto & params = smpl->params; - - const float temp = params.temp; - const int mirostat = params.mirostat; - - auto & cur_p = candidates; - - llama_token res = 0; - - if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) { - // greedy sampling, with probs - llama_sampling_softmax_impl(cur_p); - res = cur_p->data[0].id; - } else if (temp == 0.0f) { - // greedy sampling, no probs - res = llama_sampling_sample_greedy(smpl, cur_p); - } else { - if (mirostat != 0) { - llama_sampling_temp(smpl, cur_p); - res = llama_sampling_sample_mirostat(smpl, cur_p); - } else { - for (const auto & sampler : smpl->samplers) { - switch (sampler) { - case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; - default : break; - } - } - - res = llama_sampling_sample_dist(smpl, cur_p); - - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); - - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. 
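// The commented-out llama_sampling_sample above hard-codes the chain
// top_k -> tail_free -> typical -> top_p -> min_p -> temp inside one function.
// Under the v2 design the same chain becomes constraints attached to a sampler.
// A sketch of the equivalent setup, reusing the defaults retired from
// llama_sampling_default_params; the llama_constraint_init_* wrapper signatures
// are assumed to mirror their _impl counterparts:
static struct llama_sampler * make_default_sampler(const struct llama_model * model) {
    struct llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 0));
    llama_sampler_add_constraint(smpl, llama_constraint_init_tail_free(1.00f, 0));
    llama_sampler_add_constraint(smpl, llama_constraint_init_typical(1.00f, 0));
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.95f, 0));
    llama_sampler_add_constraint(smpl, llama_constraint_init_min_p(0.05f, 0));
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp(0.80f));

    // later, once the constraints have been applied to the candidates:
    //   llama_token id = llama_sampler_sample_dist(smpl, nullptr);
    return smpl;
}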
- // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); - // } - //} - - //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); - } - } - - smpl->n_sample++; +// if (type == 1) { +// res = llama_sampling_sample_mirostat_impl(candidates, +// smpl->rng, +// smpl->params.mirostat_tau, +// smpl->params.mirostat_eta, +// 100, +// smpl->vocab.n_vocab, +// smpl->mirostat_mu); +// } else if (type == 2) { +// res = llama_sampling_sample_mirostat_v2_impl(candidates, +// smpl->rng, +// smpl->params.mirostat_tau, +// smpl->params.mirostat_eta, +// smpl->mirostat_mu); +// } else { +// GGML_ABORT("invalid mirostat type: %d", type); +// } + +// smpl->n_sample++; + +// return res; +//} + +//llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); + +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } + +// auto res = llama_sampling_sample_greedy_impl(candidates); + +// smpl->n_sample++; + +// return res; +//} + +//llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); + +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } + +// auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); + +// smpl->n_sample++; + +// return res; +//} + +//llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); + +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } + +// const auto & params = smpl->params; + +// const float temp = params.temp; +// const int mirostat = params.mirostat; + +// auto & cur_p = candidates; + +// llama_token res = 0; + +// if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) { +// // greedy sampling, with probs +// llama_sampling_softmax_impl(cur_p); +// res = cur_p->data[0].id; +// } else if (temp == 0.0f) { +// // greedy sampling, no probs +// res = llama_sampling_sample_greedy(smpl, cur_p); +// } else { +// if (mirostat != 0) { +// llama_sampling_temp(smpl, cur_p); +// res = llama_sampling_sample_mirostat(smpl, cur_p); +// } else { +// for (const auto & sampler : smpl->samplers) { +// switch (sampler) { +// case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; +// default : break; +// } +// } + +// res = llama_sampling_sample_dist(smpl, cur_p); + +// //{ +// // const int n_top = 10; +// // LOG("top %d candidates:\n", n_top); + +// // for (int i = 0; i < n_top; i++) { +// // const llama_token id = cur_p.data[i].id; +// // (void)id; // To avoid a warning that id is unused when logging is disabled. 
+// // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); +// // } +// //} + +// //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); +// } +// } + +// smpl->n_sample++; - return res; -} +// return res; +//} -void llama_sampling_accept( - struct llama_sampling * smpl, - llama_token token, - bool apply_grammar) { - time_meas tm(smpl->t_accept_us); +//void llama_sampling_accept( +// struct llama_sampling * smpl, +// llama_token token, +// bool apply_grammar) { +// time_meas tm(smpl->t_accept_us); - llama_sampling_accept_impl(*smpl, token, apply_grammar); +// llama_sampling_accept_impl(*smpl, token, apply_grammar); - smpl->n_accept++; -} +// smpl->n_accept++; +//} -int llama_sampling_n_prev(const struct llama_sampling * smpl) { - return llama_sampling_n_prev_impl(*smpl); -} +//int llama_sampling_n_prev(const struct llama_sampling * smpl) { +// return llama_sampling_n_prev_impl(*smpl); +//} -llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { - return llama_sampling_prev_impl(*smpl, ith); -} +//llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { +// return llama_sampling_prev_impl(*smpl, ith); +//} -llama_token llama_sampling_last(const struct llama_sampling * smpl) { - return llama_sampling_prev_impl(*smpl, 0); -} +//llama_token llama_sampling_last(const struct llama_sampling * smpl) { +// return llama_sampling_prev_impl(*smpl, 0); +//} // // sampling v2 @@ -21006,28 +20976,32 @@ struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta return llama_constraint_init_temp_ext_impl(temp, delta, exponent); } -struct llama_constraint * llama_constraint_init_grammar(struct llama_model * model, const char * grammar_str, const char * grammar_root) { +struct llama_constraint * llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); } struct llama_constraint * llama_constraint_init_penalties( - struct llama_model * model, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { + const struct llama_model * model, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); } LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( - struct llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias) { + const struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); } +struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr) { + return llama_constraint_cp_impl(*cnstr); +} + void llama_constraint_free(struct llama_constraint * cnstr) { if (cnstr == nullptr) { return; @@ -21110,7 +21084,7 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok llama_token res; if (type == 1) { - res = llama_sampling_sample_mirostat_impl(candidates, + res = llama_sampler_sample_mirostat_impl(candidates, smpl->rng, smpl->params.mirostat_tau, smpl->params.mirostat_eta, @@ -21118,7 +21092,7 @@ llama_token 
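// Note the const-correctness fix recorded above: the grammar/penalties/logit-bias
// constructors now take a const llama_model *. Wiring the two model-dependent
// constraints could look like this (values are illustrative; a bias of -INFINITY
// effectively bans a token):
#include <cmath>

static void add_model_constraints(struct llama_sampler * smpl, const struct llama_model * model) {
    llama_sampler_add_constraint(smpl, llama_constraint_init_penalties(
        model,
        /* penalty_last_n  = */ 64,
        /* penalty_repeat  = */ 1.1f,
        /* penalty_freq    = */ 0.0f,
        /* penalty_present = */ 0.0f,
        /* penalize_nl     = */ false,
        /* ignore_eos      = */ true));

    const llama_logit_bias bias[] = {
        { /* .token = */ 2, /* .bias = */ -INFINITY }, // token id chosen arbitrarily
    };
    llama_sampler_add_constraint(smpl, llama_constraint_init_logit_bias(model, 1, bias));
}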
llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok smpl->vocab->n_vocab, smpl->mirostat_mu); } else if (type == 2) { - res = llama_sampling_sample_mirostat_v2_impl(candidates, + res = llama_sampler_sample_mirostat_v2_impl(candidates, smpl->rng, smpl->params.mirostat_tau, smpl->params.mirostat_eta, @@ -21132,14 +21106,14 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok return res; } -llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates) { +llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates, bool probs) { time_meas tm(smpl->t_sample_us); if (candidates == nullptr) { candidates = &smpl->cur_p; } - auto res = llama_sampling_sample_greedy_impl(candidates); + auto res = llama_sampler_sample_greedy_impl(candidates, probs); smpl->n_sample++; @@ -21153,7 +21127,7 @@ llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_d candidates = &smpl->cur_p; } - auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); + auto res = llama_sampler_sample_dist_impl(candidates, smpl->rng); smpl->n_sample++; @@ -21204,20 +21178,16 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) { +void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl) { const llama_timings timings = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), - /*.t_grammar_ms =*/ 1e-3 * (smpl ? smpl->t_grammar_us : 0.0), - /*.t_accept_ms =*/ 1e-3 * (smpl ? smpl->t_accept_us : 0.0), /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0), - /*.n_grammar =*/ std::max(0, smpl ? smpl->n_grammar : 0), - /*.n_accept =*/ std::max(0, smpl ? 
smpl->n_accept : 0), /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_eval =*/ std::max(1, ctx->n_eval), }; @@ -21226,10 +21196,6 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smp LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); - LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_grammar_ms, timings.n_grammar, timings.t_grammar_ms / timings.n_grammar, 1e3 / timings.t_grammar_ms * timings.n_grammar); - //LLAMA_LOG_INFO("%s: accept time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - // __func__, timings.t_accept_ms, timings.n_accept, timings.t_accept_ms / timings.n_accept, 1e3 / timings.t_accept_ms * timings.n_accept); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", @@ -21237,15 +21203,13 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smp LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; if (smpl) { smpl->t_sample_us = smpl->n_sample = 0; - smpl->t_grammar_us = smpl->n_grammar = 0; - smpl->t_accept_us = smpl->n_accept = 0; } }
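// With the grammar/accept counters removed, the timing API is a symmetric pair
// over context and sampler. Typical usage at the end of a run (sketch):
//
//   llama_print_timings(ctx, smpl); // prints load / sampling / prompt-eval / eval breakdown
//   llama_reset_timings(ctx, smpl); // zeroes t_sample_us and n_sample along with the ctx counters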