diff --git a/common/common.cpp b/common/common.cpp index f7095c7f3..2a51649a5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -841,15 +841,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.defrag_thold = std::stof(argv[i]); return true; } - if (arg == "--samplers" || arg == "--constraints") { + if (arg == "--samplers") { CHECK_ARG - const auto constraint_names = string_split(argv[i], ';'); - sparams.constraints = gpt_constraint_types_from_names(constraint_names, true); + const auto sampler_names = string_split(argv[i], ';'); + sparams.samplers = gpt_sampler_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { CHECK_ARG - sparams.constraints = gpt_constraint_types_from_chars(argv[i]); + sparams.samplers = gpt_sampler_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -1706,13 +1706,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const auto & sparams = params.sparams; - std::string constraint_type_chars; - std::string constraint_type_names; - for (const auto & constraint : sparams.constraints) { - constraint_type_chars += gpt_constraint_type_to_chr(constraint); - constraint_type_names += gpt_constraint_type_to_str(constraint) + ";"; + std::string sampler_type_chars; + std::string sampler_type_names; + for (const auto & sampler : sparams.samplers) { + sampler_type_chars += gpt_sampler_type_to_chr(sampler); + sampler_type_names += gpt_sampler_type_to_str(sampler) + ";"; } - constraint_type_names.pop_back(); + sampler_type_names.pop_back(); struct option_info { LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) @@ -1826,9 +1826,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "sampling" }); options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" - "(default: %s)", constraint_type_names.c_str() }); + "(default: %s)", sampler_type_names.c_str() }); options.push_back({ "*", " --sampling-seq SEQUENCE", - "simplified sequence for samplers that will be used (default: %s)", constraint_type_chars.c_str() }); + "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" }); options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp }); diff --git a/common/sampling.cpp b/common/sampling.cpp index 0047ead34..de7c9b1b9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,14 +2,127 @@ #include "common.h" +// the ring buffer works similarly to std::deque, but with a fixed capacity +// TODO: deduplicate with llama-impl.h +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector data; +}; + struct gpt_sampler { gpt_sampler_params params; - struct llama_constraint * bias; - struct llama_constraint * pnlt; - struct llama_constraint * grmr; + struct llama_sampler * bias; + struct llama_sampler * pnlt; + struct llama_sampler * grmr; - struct llama_sampler * smpl; + struct llama_sampler * chain; + + ring_buffer prev; + + std::vector cur; + + llama_token_data_array cur_p; + + void set_logits(struct llama_context * ctx, int idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + cur.resize(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false }; + } }; std::string gpt_sampler_params::print() const { @@ -29,28 +142,26 @@ std::string gpt_sampler_params::print() const { std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { std::string result = "\tlogits"; - for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) { - const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i); - result += std::string(" -> ") + llama_constraint_name(cnstr) + " "; + for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { + const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + result += std::string(" -> ") + llama_sampler_name(smpl) + " "; } return result; } struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { - llama_sampler_params lparams = llama_sampler_default_params(); + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); - lparams.seed = params.seed; - lparams.n_prev = params.n_prev; - lparams.type = params.temp <= 0.0f ? LLAMA_SAMPLER_TYPE_GREEDY : LLAMA_SAMPLER_TYPE_DIST; + lparams.no_timing = false; auto * result = new gpt_sampler { /* .params = */ params, - /* .bias = */ llama_constraint_init_logit_bias( + /* .bias = */ llama_sampler_init_logit_bias( model, params.logit_bias.size(), params.logit_bias.data()), - /* .pnlt = */ llama_constraint_init_penalties( + /* .pnlt = */ llama_sampler_init_penalties( model, params.penalty_last_n, params.penalty_repeat, @@ -58,45 +169,53 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st params.penalty_present, params.penalize_nl, params.ignore_eos), - /* .grmr = */ llama_constraint_init_grammar(model, params.grammar.c_str(), "root"), - /* .smpl = */ llama_sampler_init(model, lparams) + /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .chain = */ llama_sampler_chain_init(lparams), + /* .prev = */ ring_buffer(params.n_prev), + /* .cur = */ {}, + /* .cur_p = */ {}, }; if (params.temp > 0.0f) { if (params.mirostat == 0) { - for (const auto & cnstr : params.constraints) { + for (const auto & cnstr : params.samplers) { switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k)); + case GPT_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; - case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + case GPT_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + case GPT_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + case GPT_SAMPLER_TYPE_TFS_Z: + llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + case GPT_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + case GPT_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; default: - GGML_ASSERT(false && "unknown constraint type"); + GGML_ASSERT(false && "unknown sampler type"); } } } else if (params.mirostat == 1) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); } else if (params.mirostat == 2) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } + llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + } else { + llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); + llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); } return result; @@ -104,11 +223,11 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st void gpt_sampler_free(struct gpt_sampler * gsmpl) { if (gsmpl) { - llama_constraint_free(gsmpl->bias); - llama_constraint_free(gsmpl->pnlt); - llama_constraint_free(gsmpl->grmr); + llama_sampler_free(gsmpl->bias); + llama_sampler_free(gsmpl->pnlt); + llama_sampler_free(gsmpl->grmr); - llama_sampler_free(gsmpl->smpl); + llama_sampler_free(gsmpl->chain); delete gsmpl; } @@ -117,69 +236,66 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { return new gpt_sampler { /* .params = */ gsmpl->params, - /* .bias = */ llama_constraint_clone(gsmpl->bias), - /* .pnlt = */ llama_constraint_clone(gsmpl->pnlt), - /* .grmr = */ llama_constraint_clone(gsmpl->grmr), - /* .smpl = */ llama_sampler_clone (gsmpl->smpl) + /* .bias = */ llama_sampler_clone(gsmpl->bias), + /* .pnlt = */ llama_sampler_clone(gsmpl->pnlt), + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .prev = */ gsmpl->prev, + /* .cur = */ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, }; } void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) { if (apply_grammar) { - llama_constraint_accept(gsmpl->grmr, token); + llama_sampler_accept(gsmpl->grmr, token); } - llama_sampler_accept(gsmpl->smpl, token); + llama_sampler_accept(gsmpl->chain, token); + + gsmpl->prev.push_back(token); } void gpt_sampler_reset(struct gpt_sampler * gsmpl) { - llama_constraint_reset(gsmpl->grmr); + llama_sampler_reset(gsmpl->grmr); - llama_sampler_reset(gsmpl->smpl); -} - -void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits) { - llama_sampler_set_logits(gsmpl->smpl, logits); + llama_sampler_reset(gsmpl->chain); } llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { - return llama_sampler_get_candidates(gsmpl->smpl); + return &gsmpl->cur_p; } llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { - return llama_sampler_last(gsmpl->smpl); + return gsmpl->prev.rat(0); } -void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { - llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); -} - -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p) { - return llama_sampler_sample(gsmpl->smpl, cur_p); +void gpt_print_timings(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { + llama_print_timings(ctx, gsmpl ? gsmpl->chain : nullptr); } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - auto & bias = gsmpl->bias; - auto & pnlt = gsmpl->pnlt; - auto & grmr = gsmpl->grmr; - auto & smpl = gsmpl->smpl; + auto & bias = gsmpl->bias; + auto & pnlt = gsmpl->pnlt; + auto & grmr = gsmpl->grmr; + auto & chain = gsmpl->chain; - const auto * logits = llama_get_logits_ith(ctx, idx); + gsmpl->set_logits(ctx, idx); - llama_sampler_set_logits(smpl, logits); + auto & cur_p = gsmpl->cur_p; - auto * cur_p = llama_sampler_get_candidates(smpl); - - llama_constraint_apply(bias, cur_p); - llama_constraint_apply(pnlt, cur_p); + llama_sampler_apply(bias, &cur_p); + llama_sampler_apply(pnlt, &cur_p); if (grammar_first) { - llama_constraint_apply(grmr, cur_p); + llama_sampler_apply(grmr, &cur_p); } - llama_sampler_apply(smpl, cur_p); + llama_sampler_apply(chain, &cur_p); - const llama_token id = llama_sampler_sample(smpl, cur_p); + const llama_token id = cur_p.data[cur_p.selected].id; + + GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration"); if (grammar_first) { return id; @@ -188,9 +304,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context // check if it the sampled token fits the grammar { llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, LLAMA_TOKEN_NULL, false }; - llama_constraint_apply(grmr, &single_token_data_array); + llama_sampler_apply(grmr, &single_token_data_array); // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; @@ -199,28 +315,22 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context } } - // if the token is not valid, sample again, first apply the grammar constraints and then sample - llama_sampler_set_logits(smpl, logits); + // if the token is not valid, sample again, first apply the grammar samplers and then sample + gsmpl->set_logits(ctx, idx); - llama_constraint_apply(bias, cur_p); - llama_constraint_apply(pnlt, cur_p); - llama_constraint_apply(grmr, cur_p); + llama_sampler_apply(bias, &cur_p); + llama_sampler_apply(pnlt, &cur_p); + llama_sampler_apply(grmr, &cur_p); - llama_sampler_apply(smpl, cur_p); + llama_sampler_apply(chain, &cur_p); - return llama_sampler_sample(smpl, cur_p); -} + GGML_ASSERT(cur_p.data[cur_p.selected].id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration"); -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { - GGML_ASSERT(cur_p != nullptr); - - llama_constraint_apply(gsmpl->grmr, cur_p); + return cur_p.data[cur_p.selected].id; } std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { - auto & smpl = gsmpl->smpl; - - n = std::min(n, llama_sampler_n_prev(smpl)); + n = std::min(n, (int) gsmpl->prev.size()); if (n <= 0) { return ""; @@ -230,7 +340,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab for (int i = n - 1; i >= 0; i--) { - const llama_token id = llama_sampler_prev(smpl, i); + const llama_token id = gsmpl->prev.rat(i); GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); @@ -240,95 +350,95 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, return result; } -char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr) { +char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) { switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: return 'k'; - case GPT_CONSTRAINT_TYPE_TFS_Z: return 'f'; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; - case GPT_CONSTRAINT_TYPE_TOP_P: return 'p'; - case GPT_CONSTRAINT_TYPE_MIN_P: return 'm'; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: return 't'; + case GPT_SAMPLER_TYPE_TOP_K: return 'k'; + case GPT_SAMPLER_TYPE_TFS_Z: return 'f'; + case GPT_SAMPLER_TYPE_TYPICAL_P: return 'y'; + case GPT_SAMPLER_TYPE_TOP_P: return 'p'; + case GPT_SAMPLER_TYPE_MIN_P: return 'm'; + case GPT_SAMPLER_TYPE_TEMPERATURE: return 't'; default : return '?'; } } -std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr) { +std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) { switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: return "top_k"; - case GPT_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; - case GPT_CONSTRAINT_TYPE_TOP_P: return "top_p"; - case GPT_CONSTRAINT_TYPE_MIN_P: return "min_p"; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; + case GPT_SAMPLER_TYPE_TOP_K: return "top_k"; + case GPT_SAMPLER_TYPE_TFS_Z: return "tfs_z"; + case GPT_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; + case GPT_SAMPLER_TYPE_TOP_P: return "top_p"; + case GPT_SAMPLER_TYPE_MIN_P: return "min_p"; + case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } -std::vector gpt_constraint_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map constraint_canonical_name_map { - { "top_k", GPT_CONSTRAINT_TYPE_TOP_K }, - { "top_p", GPT_CONSTRAINT_TYPE_TOP_P }, - { "typ_p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "min_p", GPT_CONSTRAINT_TYPE_MIN_P }, - { "tfs_z", GPT_CONSTRAINT_TYPE_TFS_Z }, - { "temperature", GPT_CONSTRAINT_TYPE_TEMPERATURE }, +std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + { "top_k", GPT_SAMPLER_TYPE_TOP_K }, + { "top_p", GPT_SAMPLER_TYPE_TOP_P }, + { "typ_p", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", GPT_SAMPLER_TYPE_MIN_P }, + { "tfs_z", GPT_SAMPLER_TYPE_TFS_Z }, + { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE }, }; - // since constraints names are written multiple ways + // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map constraint_alt_name_map { - { "top-k", GPT_CONSTRAINT_TYPE_TOP_K }, - { "top-p", GPT_CONSTRAINT_TYPE_TOP_P }, - { "nucleus", GPT_CONSTRAINT_TYPE_TOP_P }, - { "typical-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "typical", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "min-p", GPT_CONSTRAINT_TYPE_MIN_P }, - { "tfs-z", GPT_CONSTRAINT_TYPE_TFS_Z }, - { "tfs", GPT_CONSTRAINT_TYPE_TFS_Z }, - { "temp", GPT_CONSTRAINT_TYPE_TEMPERATURE }, + std::unordered_map sampler_alt_name_map { + { "top-k", GPT_SAMPLER_TYPE_TOP_K }, + { "top-p", GPT_SAMPLER_TYPE_TOP_P }, + { "nucleus", GPT_SAMPLER_TYPE_TOP_P }, + { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "typical", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "typ-p", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "typ", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "min-p", GPT_SAMPLER_TYPE_MIN_P }, + { "tfs-z", GPT_SAMPLER_TYPE_TFS_Z }, + { "tfs", GPT_SAMPLER_TYPE_TFS_Z }, + { "temp", GPT_SAMPLER_TYPE_TEMPERATURE }, }; - std::vector constraints; - constraints.reserve(names.size()); + std::vector samplers; + samplers.reserve(names.size()); for (const auto & name : names) { - auto constraint = constraint_canonical_name_map.find(name); - if (constraint != constraint_canonical_name_map.end()) { - constraints.push_back(constraint->second); + auto sampler = sampler_canonical_name_map.find(name); + if (sampler != sampler_canonical_name_map.end()) { + samplers.push_back(sampler->second); } else { if (allow_alt_names) { - constraint = constraint_alt_name_map.find(name); - if (constraint != constraint_alt_name_map.end()) { - constraints.push_back(constraint->second); + sampler = sampler_alt_name_map.find(name); + if (sampler != sampler_alt_name_map.end()) { + samplers.push_back(sampler->second); } } } } - return constraints; + return samplers; } -std::vector gpt_constraint_types_from_chars(const std::string & chars) { - std::unordered_map constraint_name_map { - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_K), GPT_CONSTRAINT_TYPE_TOP_K }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TFS_Z), GPT_CONSTRAINT_TYPE_TFS_Z }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TYPICAL_P), GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_P), GPT_CONSTRAINT_TYPE_TOP_P }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_MIN_P), GPT_CONSTRAINT_TYPE_MIN_P }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TEMPERATURE), GPT_CONSTRAINT_TYPE_TEMPERATURE } +std::vector gpt_sampler_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map { + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K), GPT_SAMPLER_TYPE_TOP_K }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z), GPT_SAMPLER_TYPE_TFS_Z }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE } }; - std::vector constraints; - constraints.reserve(chars.size()); + std::vector samplers; + samplers.reserve(chars.size()); for (const auto & c : chars) { - const auto constraint = constraint_name_map.find(c); - if (constraint != constraint_name_map.end()) { - constraints.push_back(constraint->second); + const auto sampler = sampler_name_map.find(c); + if (sampler != sampler_name_map.end()) { + samplers.push_back(sampler->second); } } - return constraints; + return samplers; } diff --git a/common/sampling.h b/common/sampling.h index c260ef055..5083f456f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -5,14 +5,14 @@ #include #include -enum gpt_constraint_type { - GPT_CONSTRAINT_TYPE_NONE = 0, - GPT_CONSTRAINT_TYPE_TOP_K = 1, - GPT_CONSTRAINT_TYPE_TOP_P = 2, - GPT_CONSTRAINT_TYPE_MIN_P = 3, - GPT_CONSTRAINT_TYPE_TFS_Z = 4, - GPT_CONSTRAINT_TYPE_TYPICAL_P = 5, - GPT_CONSTRAINT_TYPE_TEMPERATURE = 6, +enum gpt_sampler_type { + GPT_SAMPLER_TYPE_NONE = 0, + GPT_SAMPLER_TYPE_TOP_K = 1, + GPT_SAMPLER_TYPE_TOP_P = 2, + GPT_SAMPLER_TYPE_MIN_P = 3, + GPT_SAMPLER_TYPE_TFS_Z = 4, + GPT_SAMPLER_TYPE_TYPICAL_P = 5, + GPT_SAMPLER_TYPE_TEMPERATURE = 6, }; // sampling parameters @@ -21,7 +21,7 @@ struct gpt_sampler_params { int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise constraints should return at least min_keep tokens + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled @@ -40,13 +40,13 @@ struct gpt_sampler_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector constraints = { - GPT_CONSTRAINT_TYPE_TOP_K, - GPT_CONSTRAINT_TYPE_TFS_Z, - GPT_CONSTRAINT_TYPE_TYPICAL_P, - GPT_CONSTRAINT_TYPE_TOP_P, - GPT_CONSTRAINT_TYPE_MIN_P, - GPT_CONSTRAINT_TYPE_TEMPERATURE + std::vector samplers = { + GPT_SAMPLER_TYPE_TOP_K, + GPT_SAMPLER_TYPE_TFS_Z, + GPT_SAMPLER_TYPE_TYPICAL_P, + GPT_SAMPLER_TYPE_TOP_P, + GPT_SAMPLER_TYPE_MIN_P, + GPT_SAMPLER_TYPE_TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling @@ -73,40 +73,36 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); - -void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); - llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); +//llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); -void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); +void gpt_print_timings(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); // extended sampling implementation: // // - set logits -// - apply the configured sampling constraints +// - apply the configured sampler chain // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -// if grammar_first is true, the grammar is applied before the constraints (slower) +// if grammar_first is true, the grammar is applied before the samplers (slower) // useful in cases where all the resulting candidates must fit the grammar // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // helpers -// print the constraints into a string +// print the sampler chain into a string std::string gpt_sampler_print(const struct gpt_sampler * gsmpl); // get a string representation of the last accepted tokens std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); -char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr); -std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr); +char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr); +std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr); -std::vector gpt_constraint_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector gpt_constraint_types_from_chars(const std::string & chars); +std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector gpt_sampler_types_from_chars(const std::string & chars); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index a02fa4da9..24f8a7027 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -50,9 +50,9 @@ defer { llama_free(context) } -var sparams = llama_sampler_params() +var sparams = llama_sampler_chain_default_params() -let smpl = llama_sampler_init(model, sparams) +let smpl = llama_sampler_chain_init(sparams) guard smpl != nil else { print("Failed to initialize sampling") exit(1) @@ -61,9 +61,9 @@ defer { llama_sampler_free(smpl) } -llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40)); -llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); -llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); +llama_sampler_sampler_add(smpl, llama_sampler_init_top_k(40)); +llama_sampler_sampler_add(smpl, llama_sampler_init_top_p(0.9, 1)); +llama_sampler_sampler_add(smpl, llama_sampler_init_temp (0.4)); let n_ctx = llama_n_ctx(context) @@ -137,11 +137,9 @@ while n_cur <= n_len { continue } - var logits = llama_get_logits_ith(context, i_batch[i]) + let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) - llama_sampler_set_logits(smpl, logits) - - let new_token_id = llama_sampler_sample(smpl, nil) + llama_sampler_accept(smpl, new_token_id) // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 5896526ab..b6e98fcc3 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,15 +64,13 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.seed = params.sparams.seed; + llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler * smpl = llama_sampler_init(model, sparams); - - llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k)); - llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); - llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -173,11 +171,9 @@ int main(int argc, char ** argv) { continue; } - const auto * logits = llama_get_logits_ith(ctx, i_batch[i]); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); - llama_sampler_set_logits(smpl, logits); - - const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); + llama_sampler_accept(smpl, new_token_id); // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 07475ecd3..b402abbb8 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -120,11 +120,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_decode(ctx, bat); - const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); + llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); + llama_sampler_accept(smpl, token); - llama_sampler_set_logits(smpl, logits); - - llama_token token = llama_sampler_sample(smpl, nullptr); if (token == eos_token) { break; } @@ -171,11 +169,9 @@ int main(int argc, char * argv[]) { // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; - - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler * smpl = llama_sampler_chain_init(sparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 1a4908501..921793751 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -394,12 +394,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); - const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1); - - llama_sampler_set_logits(sampling, logits); - // sample the most likely token - const auto new_token_id = llama_sampler_sample(sampling, nullptr); + const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1); + + llama_sampler_accept(sampling, new_token_id); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index bd6513d34..73cabc6c7 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -43,9 +43,8 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] - var sparams = llama_sampler_default_params() - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY - self.sampling = llama_sampler_init(context, sparams) + var sparams = llama_sampler_chain_default_params() + self.sampling = llama_sampler_chain_init(sparams) } deinit { @@ -148,12 +147,9 @@ actor LlamaContext { func completion_loop() -> String { var new_token_id: llama_token = 0 - let n_vocab = llama_n_vocab(model) - let logits = llama_get_logits_ith(context, batch.n_tokens - 1) + new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - llama_sampler_set_logits(sampling, logits); - - new_token_id = llama_sampler_sample(sampling, nil) + llama_sampler_accept(sampling, new_token_id) if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index b9800a917..92c71c5a1 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -83,11 +83,11 @@ int main(int argc, char ** argv) { return 1; } - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); // tokenize the prompt std::vector tokens_list; @@ -220,12 +220,9 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - llama_sampler_set_logits(smpl, logits); - - // sample the most likely token - const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); + llama_sampler_accept(smpl, new_token_id); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 6f8c84137..133a010e4 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -38,10 +38,12 @@ int main(int argc, char ** argv) { return 1; } - llama_sampler_params sparams = llama_sampler_default_params(); - sparams.seed = params.sparams.seed; + auto sparams = llama_sampler_chain_default_params(); - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_softmax()); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed)); // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -69,13 +71,11 @@ int main(int argc, char ** argv) { printf("\nfirst run: %s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - const auto * logits = llama_get_logits(ctx); - - llama_sampler_set_logits(smpl, logits); - - auto next_token = llama_sampler_sample(smpl, nullptr); + auto next_token = llama_sampler_sample(smpl, ctx, -1); auto next_token_str = llama_token_to_piece(ctx, next_token); + llama_sampler_accept(smpl, next_token); + printf("%s", next_token_str.c_str()); result0 += next_token_str; @@ -96,7 +96,10 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampler * smpl2 = llama_sampler_init(model, sparams); + llama_sampler * smpl2 = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl2, llama_sampler_init_softmax()); + llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed)); printf("\nsecond run: %s", params.prompt.c_str()); @@ -126,13 +129,11 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - const auto * logits = llama_get_logits(ctx2); - - llama_sampler_set_logits(smpl2, logits); - - auto next_token = llama_sampler_sample(smpl2, nullptr); + auto next_token = llama_sampler_sample(smpl2, ctx2, -1); auto next_token_str = llama_token_to_piece(ctx2, next_token); + llama_sampler_accept(smpl2, next_token); + printf("%s", next_token_str.c_str()); result1 += next_token_str; @@ -157,7 +158,10 @@ int main(int argc, char ** argv) { // make new context auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampler * smpl3 = llama_sampler_init(model, sparams); + llama_sampler * smpl3 = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl3, llama_sampler_init_softmax()); + llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed)); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -215,13 +219,11 @@ int main(int argc, char ** argv) { // third run with seq 1 instead of 0 for (auto i = 0; i < params.n_predict; i++) { - const auto * logits = llama_get_logits(ctx3); - - llama_sampler_set_logits(smpl3, logits); - - auto next_token = llama_sampler_sample(smpl3, nullptr); + auto next_token = llama_sampler_sample(smpl3, ctx3, -1); auto next_token_str = llama_token_to_piece(ctx3, next_token); + llama_sampler_accept(smpl3, next_token); + printf("%s", next_token_str.c_str()); result2 += next_token_str; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 03e512e03..1095f43b2 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1027,17 +1027,17 @@ struct server_context { } { - const auto & constraints = data.find("samplers"); - if (constraints != data.end() && constraints->is_array()) { - std::vector constraint_names; - for (const auto & name : *constraints) { + const auto & samplers = data.find("samplers"); + if (samplers != data.end() && samplers->is_array()) { + std::vector sampler_names; + for (const auto & name : *samplers) { if (name.is_string()) { - constraint_names.emplace_back(name); + sampler_names.emplace_back(name); } } - slot.sparams.constraints = gpt_constraint_types_from_names(constraint_names, false); + slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false); } else { - slot.sparams.constraints = default_sparams.constraints; + slot.sparams.samplers = default_sparams.samplers; } } @@ -1253,10 +1253,10 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - std::vector constraints; - constraints.reserve(slot.sparams.constraints.size()); - for (const auto & constraint : slot.sparams.constraints) { - constraints.emplace_back(gpt_constraint_type_to_str(constraint)); + std::vector samplers; + samplers.reserve(slot.sparams.samplers.size()); + for (const auto & sampler : slot.sparams.samplers) { + samplers.emplace_back(gpt_sampler_type_to_str(sampler)); } return json { @@ -1290,7 +1290,7 @@ struct server_context { {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", constraints}, + {"samplers", samplers}, }; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 7193f1ee4..e5dfeb2f4 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,11 +55,9 @@ int main(int argc, char ** argv) { return 1; } - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; - - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler * smpl = llama_sampler_chain_init(sparams); // tokenize the prompt @@ -116,12 +114,9 @@ int main(int argc, char ** argv) { while (n_cur <= n_predict) { // sample the next token { - const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - llama_sampler_set_logits(smpl, logits); - - // sample the most likely token - const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); + llama_sampler_accept(smpl, new_token_id); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 9f596ec91..037d5d34b 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -179,7 +179,7 @@ int main(int argc, char ** argv) { // target model sampling context (reuse the llama_context's sampling instance) struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); - struct llama_constraint * softmax = llama_constraint_init_softmax(); + struct llama_sampler * softmax = llama_sampler_init_softmax(); // draft sequence data std::vector drafts(n_seq_dft); @@ -255,7 +255,7 @@ int main(int argc, char ** argv) { LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); - llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true }; //GGML_ASSERT(dist_tgt.size <= dist_dft.size); @@ -625,7 +625,7 @@ int main(int argc, char ** argv) { gpt_sampler_free(drafts[s].smpl); } - llama_constraint_free(softmax); + llama_sampler_free(softmax); llama_batch_free(batch_dft); llama_free(ctx_tgt); diff --git a/include/llama.h b/include/llama.h index dd047e0ac..f73541389 100644 --- a/include/llama.h +++ b/include/llama.h @@ -216,6 +216,7 @@ extern "C" { // TODO: consider SoA llama_token_data * data; size_t size; + int64_t selected; bool sorted; } llama_token_data_array; @@ -369,21 +370,9 @@ extern "C" { float bias; } llama_logit_bias; - enum llama_sampler_type { - LLAMA_SAMPLER_TYPE_GREEDY = 0, - LLAMA_SAMPLER_TYPE_DIST = 1, - }; - - typedef struct llama_sampler_params { - uint32_t seed; // the seed used to initialize the rng of the sampler - - int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API) - - // TODO: will be used by the llama_decode_with_sampler() API in the future - enum llama_sampler_type type; - + typedef struct llama_sampler_chain_params { bool no_timing; // whether to measure performance timings - } llama_sampler_params; + } llama_sampler_chain_params; // performance timing information struct llama_timings { @@ -412,7 +401,7 @@ extern "C" { // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); - LLAMA_API struct llama_sampler_params llama_sampler_default_params(void); + LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); // Initialize the llama + ggml backend @@ -1003,70 +992,73 @@ extern "C" { // // Sampling API // - // - Constraints - // The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their - // logits and probabilities inplace. The interface is abstracted so that users can implement custom constraints. - // - // - Samplers - // The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the - // sampler can apply a sequence of constraints in order to modify the probabilities of the candidates. - // - // The llama_sampler object contains the entire sampling information: - // - // - RNG state (seed and generator) - // - Custom set of constraints (see llama_sampler_constraint_add) - // - Sampling method (greedy, dist) - // - Previous tokens - // // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). // // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model - // constraints + typedef void * llama_sampler_context_t; - struct llama_constraint; - - typedef void * llama_constraint_context_t; - - // user code can implement the interface below in order to create custom llama_constraint - struct llama_constraint_i { - const char * (*name) (const struct llama_constraint * cnstr); // can be NULL - void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required - void (*reset) ( struct llama_constraint * cnstr); // can be NULL - struct llama_constraint * (*clone) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL - void (*free) ( struct llama_constraint * cnstr); // can be NULL if ctx is NULL + // user code can implement the interface below in order to create custom llama_sampler + struct llama_sampler_i { + const char * (*name) (const struct llama_sampler * smpl); // can be NULL + void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL + void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required + void (*reset) ( struct llama_sampler * smpl); // can be NULL + struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL + void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_constraint * cnstr, ...); + //void (*apply_ggml) (struct llama_sampler * smpl, ...); }; - struct llama_constraint { - struct llama_constraint_i * iface; - llama_constraint_context_t ctx; + struct llama_sampler { + struct llama_sampler_i * iface; + llama_sampler_context_t ctx; }; + LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); + // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + + // llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers + + LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); + + // important: takes ownership of the sampler object and will free it when llama_sampler_free is called + LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + + // available samplers: + + LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); + LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k); + LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, int32_t min_keep); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, int32_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, int32_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -1074,7 +1066,7 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API struct llama_constraint * llama_constraint_init_mirostat( + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( const struct llama_model * model, float tau, float eta); @@ -1084,16 +1076,16 @@ extern "C" { /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2( + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( float tau, float eta); - LLAMA_API struct llama_constraint * llama_constraint_init_grammar( + LLAMA_API struct llama_sampler * llama_sampler_init_grammar( const struct llama_model * model, const char * grammar_str, const char * grammar_root); - LLAMA_API struct llama_constraint * llama_constraint_init_penalties( + LLAMA_API struct llama_sampler * llama_sampler_init_penalties( const struct llama_model * model, int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat, // 1.0 = disabled @@ -1102,57 +1094,14 @@ extern "C" { bool penalize_nl, // consider newlines as a repeatable token bool ignore_eos); // ignore the end-of-sequence token - LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( const struct llama_model * model, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - LLAMA_API struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr); - - // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add) - LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); - - LLAMA_API const char * llama_constraint_name (const struct llama_constraint * cnstr); - LLAMA_API void llama_constraint_accept( struct llama_constraint * cnstr, llama_token token); - LLAMA_API void llama_constraint_apply ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); - LLAMA_API void llama_constraint_reset ( struct llama_constraint * cnstr); - - // samplers - - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); - LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); - LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); - LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); - LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); - - LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); - - LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); - - // important: takes ownership of the constraint object and will free it in llama_sampler_free - LLAMA_API void llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr); - LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl); - LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); - - - LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p); - - /// @details Get the number of accepted tokens so far (max of n_prev) - LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); - - /// @details Get the ith accepted token - /// @param ith [0, n_prev), ith == 0 is the last accepted token. - /// returns LLAMA_TOKEN_NULL if ith is out of bounds - LLAMA_API llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith); - - /// @details Get the last accepted token - /// Same as llama_sampler_prev(smpl, 0) - /// returns LLAMA_TOKEN_NULL if there are no accepted tokens - LLAMA_API llama_token llama_sampler_last(const struct llama_sampler * smpl); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // TODO: extend in the future - //LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t i); //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); // @@ -1172,8 +1121,9 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl); - LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl); + // note: requires llama_sampler_chain. how to prevent misuse? + LLAMA_API void llama_print_timings(const struct llama_context * ctx, const struct llama_sampler * chain); + LLAMA_API void llama_reset_timings( struct llama_context * ctx, struct llama_sampler * chain); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/src/llama-impl.h b/src/llama-impl.h index 6d388655d..fa2e09e1f 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -32,6 +32,20 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * // helpers // +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { return; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index cf28baab5..735992faa 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include static void llama_log_softmax(float * array, size_t size) { @@ -24,7 +25,7 @@ static void llama_log_softmax(float * array, size_t size) { } } -static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { GGML_ASSERT(cur_p->size > 0); // Sort the logits in descending order @@ -49,7 +50,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { } } -static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) { +static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; @@ -125,12 +126,12 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t cur_p->size = k; } -static void llama_constraint_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { +static void llama_sampler_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -151,7 +152,7 @@ static void llama_constraint_top_p_impl(llama_token_data_array * cur_p, float p, cur_p->size = last_idx; } -static void llama_constraint_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { +static void llama_sampler_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { if (p <= 0.0f || !cur_p->size) { return; } @@ -206,12 +207,12 @@ static void llama_constraint_min_p_impl(llama_token_data_array * cur_p, float p, } } -static void llama_constraint_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) { +static void llama_sampler_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) { if (z >= 1.0f || cur_p->size <= 2) { return; } - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Compute the first and second derivatives std::vector first_derivatives(cur_p->size - 1); @@ -260,7 +261,7 @@ static void llama_constraint_tail_free_impl(llama_token_data_array * cur_p, floa cur_p->size = last_idx; } -static void llama_constraint_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { +static void llama_sampler_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -268,7 +269,7 @@ static void llama_constraint_typical_impl(llama_token_data_array * cur_p, float } // Compute the softmax of logits and calculate entropy - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); float entropy = 0.0f; for (size_t i = 0; i < cur_p->size; ++i) { @@ -318,7 +319,7 @@ static void llama_constraint_typical_impl(llama_token_data_array * cur_p, float cur_p->sorted = false; } -static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) { +static void llama_sampler_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if (cur_p->size <= 1) { return; @@ -327,7 +328,7 @@ static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float // Calculate maximum possible entropy float max_entropy = -logf(1.0f / cur_p->size); - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -381,17 +382,17 @@ static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float #endif } -static void llama_constraint_temp_impl(llama_token_data_array * cur_p, float temp) { +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { for (size_t i = 0; i < cur_p->size; ++i) { cur_p->data[i].logit /= temp; } } -static void llama_constraint_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) { +static void llama_sampler_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) { llama_grammar_apply_impl(grammar, cur_p); } -void llama_constraint_penalties_impl( +void llama_sampler_penalties_impl( llama_token_data_array * cur_p, const llama_token_cnt & token_count, float penalty_repeat, @@ -421,56 +422,124 @@ void llama_constraint_penalties_impl( } // -// constraints +// samplers // -// softmax +// greedy -static struct llama_constraint_i llama_constraint_softmax_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "softmax"; }, +static struct llama_sampler_i llama_sampler_greedy_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "greedy"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) { - llama_constraint_softmax_impl(cur_p); + /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + cur_p->selected = 0; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { + cur_p->selected = i; + } + } }, /* .reset = */ nullptr, /* .clone = */ nullptr, /* .free = */ nullptr, }; -struct llama_constraint * llama_constraint_init_softmax_impl() { - return new llama_constraint { - /* .iface = */ &llama_constraint_softmax_i, +struct llama_sampler * llama_sampler_init_greedy_impl() { + return new llama_sampler { + /* .iface = */ &llama_sampler_greedy_i, + /* .ctx = */ nullptr, + }; +} + +// dist + +struct llama_sampler_context_dist { + const uint32_t seed; + + std::mt19937 rng; +}; + +static struct llama_sampler_i llama_sampler_dist_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; }, + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_dist *) smpl->ctx; + std::vector probs; + probs.reserve(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + probs.push_back(cur_p->data[i].p); + } + + std::discrete_distribution dist(probs.begin(), probs.end()); + + cur_p->selected = dist(ctx->rng); + }, + /* .reset = */ nullptr, + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx; + return llama_sampler_init_dist_impl(ctx->seed); + }, + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_dist *) smpl->ctx; + }, +}; + +struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) { + return new llama_sampler { + /* .iface = */ &llama_sampler_dist_i, + /* .ctx = */ new llama_sampler_context_dist { + /* .seed = */ seed, + /* .rng = */ std::mt19937(seed), + }, + }; +} + +// softmax + +static struct llama_sampler_i llama_sampler_softmax_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "softmax"; }, + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + llama_sampler_softmax_impl(cur_p); + }, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_softmax_impl() { + return new llama_sampler { + /* .iface = */ &llama_sampler_softmax_i, /* .ctx = */ nullptr, }; } // top-k -struct llama_constraint_context_top_k { +struct llama_sampler_context_top_k { const int32_t k; }; -static struct llama_constraint_i llama_constraint_top_k_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; }, +static struct llama_sampler_i llama_sampler_top_k_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_constraint_top_k_impl(cur_p, ctx->k); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_top_k *) smpl->ctx; + llama_sampler_top_k_impl(cur_p, ctx->k); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; - return llama_constraint_init_top_k_impl(ctx->k); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx; + return llama_sampler_init_top_k_impl(ctx->k); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_top_k *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_top_k *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) { - return new llama_constraint { - /* .iface = */ &llama_constraint_top_k_i, - /* .ctx = */ new llama_constraint_context_top_k { +struct llama_sampler * llama_sampler_init_top_k_impl(int32_t k) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_k_i, + /* .ctx = */ new llama_sampler_context_top_k { /* .k = */ k, }, }; @@ -478,32 +547,32 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) { // top-p -struct llama_constraint_context_top_p { +struct llama_sampler_context_top_p { const float p; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_top_p_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-p"; }, +static struct llama_sampler_i llama_sampler_top_p_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; - llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_top_p *) smpl->ctx; + llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_top_p *) cnstr->ctx; - return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx; + return llama_sampler_init_top_p_impl(ctx->p, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_top_p *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_top_p *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_top_p_i, - /* .ctx = */ new llama_constraint_context_top_p { +struct llama_sampler * llama_sampler_init_top_p_impl(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_p_i, + /* .ctx = */ new llama_sampler_context_top_p { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -512,32 +581,32 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k // min-p -struct llama_constraint_context_min_p { +struct llama_sampler_context_min_p { const float p; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_min_p_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "min-p"; }, +static struct llama_sampler_i llama_sampler_min_p_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; - llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_min_p *) smpl->ctx; + llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_min_p *) cnstr->ctx; - return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx; + return llama_sampler_init_min_p_impl(ctx->p, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_min_p *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_min_p *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_min_p_i, - /* .ctx = */ new llama_constraint_context_min_p { +struct llama_sampler * llama_sampler_init_min_p_impl(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_min_p_i, + /* .ctx = */ new llama_sampler_context_min_p { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -546,32 +615,32 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k // tail-free -struct llama_constraint_context_tail_free { +struct llama_sampler_context_tail_free { const float z; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_tail_free_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "tail-free"; }, +static struct llama_sampler_i llama_sampler_tail_free_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; - llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_tail_free *) smpl->ctx; + llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_tail_free *) cnstr->ctx; - return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx; + return llama_sampler_init_tail_free_impl(ctx->z, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_tail_free *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_tail_free *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_tail_free_i, - /* .ctx = */ new llama_constraint_context_tail_free { +struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_tail_free_i, + /* .ctx = */ new llama_sampler_context_tail_free { /* .z = */ z, /*. min_keep = */ min_keep, }, @@ -580,32 +649,32 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m // typical -struct llama_constraint_context_typical { +struct llama_sampler_context_typical { const float p; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_typical_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "typical"; }, +static struct llama_sampler_i llama_sampler_typical_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; - llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_typical *) smpl->ctx; + llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_typical *) cnstr->ctx; - return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx; + return llama_sampler_init_typical_impl(ctx->p, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_typical *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_typical *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_typical_i, - /* .ctx = */ new llama_constraint_context_typical { +struct llama_sampler * llama_sampler_init_typical_impl(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_typical_i, + /* .ctx = */ new llama_sampler_context_typical { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -614,31 +683,31 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min // temp -struct llama_constraint_context_temp { +struct llama_sampler_context_temp { const float temp; }; -static struct llama_constraint_i llama_constraint_temp_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp"; }, +static struct llama_sampler_i llama_sampler_temp_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; - llama_constraint_temp_impl(cur_p, ctx->temp); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_temp *) smpl->ctx; + llama_sampler_temp_impl(cur_p, ctx->temp); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_temp *) cnstr->ctx; - return llama_constraint_init_temp_impl(ctx->temp); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx; + return llama_sampler_init_temp_impl(ctx->temp); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_temp *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_temp *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_temp_impl(float temp) { - return new llama_constraint { - /* .iface = */ &llama_constraint_temp_i, - /* .ctx = */ new llama_constraint_context_temp { +struct llama_sampler * llama_sampler_init_temp_impl(float temp) { + return new llama_sampler { + /* .iface = */ &llama_sampler_temp_i, + /* .ctx = */ new llama_sampler_context_temp { /*.temp = */ temp, }, }; @@ -646,40 +715,40 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { // temp-ext -struct llama_constraint_context_temp_ext { +struct llama_sampler_context_temp_ext { const float temp; const float delta; const float exponent; }; -static struct llama_constraint_i llama_constraint_temp_ext_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp-ext"; }, +static struct llama_sampler_i llama_sampler_temp_ext_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_temp_ext *) smpl->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; - llama_constraint_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent); + llama_sampler_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent); } else { - llama_constraint_temp_impl(cur_p, ctx->temp); + llama_sampler_temp_impl(cur_p, ctx->temp); } }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_temp_ext *) cnstr->ctx; - return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx; + return llama_sampler_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_temp_ext *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_temp_ext *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { - return new llama_constraint { - /* .iface = */ &llama_constraint_temp_ext_i, - /* .ctx = */ new llama_constraint_context_temp_ext { +struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, float exponent) { + return new llama_sampler { + /* .iface = */ &llama_sampler_temp_ext_i, + /* .ctx = */ new llama_sampler_context_temp_ext { /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, @@ -689,7 +758,7 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float // mirostat -struct llama_constraint_context_mirostat { +struct llama_sampler_context_mirostat { const struct llama_vocab * vocab; const float tau; @@ -702,10 +771,10 @@ struct llama_constraint_context_mirostat { std::vector cur; }; -static struct llama_constraint_i llama_constraint_mirostat_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_mirostat_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; int32_t idx = -1; for (size_t i = 0; i < ctx->cur.size(); ++i) { @@ -721,10 +790,10 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { // Update mu using the learning rate and error ctx->mu = ctx->mu - ctx->eta * e; }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -742,7 +811,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { float epsilon_hat = s_hat - 1; float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); - llama_constraint_top_k_impl(cur_p, std::max(int(k), 1)); + llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); // remember the order to be able to compute the distance later when accepting the token ctx->cur.resize(cur_p->size); @@ -750,23 +819,23 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { ctx->cur[i] = cur_p->data[i]; } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; - return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx; + return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_mirostat *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_mirostat *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) { - return new llama_constraint { - /* .iface = */ &llama_constraint_mirostat_i, - /* .ctx = */ new llama_constraint_context_mirostat { +struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) { + return new llama_sampler { + /* .iface = */ &llama_sampler_mirostat_i, + /* .ctx = */ new llama_sampler_context_mirostat { /* .vocab = */ &vocab, /* .tau = */ tau, /* .eta = */ eta, @@ -779,7 +848,7 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama // mirostat v2 -struct llama_constraint_context_mirostat_v2 { +struct llama_sampler_context_mirostat_v2 { const float tau; const float eta; @@ -788,10 +857,10 @@ struct llama_constraint_context_mirostat_v2 { std::vector cur; }; -static struct llama_constraint_i llama_constraint_mirostat_v2_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat-v2"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_mirostat_v2_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; int32_t idx = -1; for (size_t i = 0; i < ctx->cur.size(); ++i) { @@ -807,10 +876,10 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { // Update mu using the learning rate and error ctx->mu = ctx->mu - ctx->eta * e; }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Truncate the words with surprise values greater than mu cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { @@ -822,7 +891,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { } // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // remember the order to be able to compute the distance later when accepting the token ctx->cur.resize(cur_p->size); @@ -830,23 +899,23 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { ctx->cur[i] = cur_p->data[i]; } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; - return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; + return llama_sampler_init_mirostat_v2_impl(ctx->tau, ctx->eta); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) { - return new llama_constraint { - /* .iface = */ &llama_constraint_mirostat_v2_i, - /* .ctx = */ new llama_constraint_context_mirostat_v2 { +struct llama_sampler * llama_sampler_init_mirostat_v2_impl(float tau, float eta) { + return new llama_sampler { + /* .iface = */ &llama_sampler_mirostat_v2_i, + /* .ctx = */ new llama_sampler_context_mirostat_v2 { /* .tau = */ tau, /* .eta = */ eta, /* .mu = */ 2.0f*tau, @@ -857,7 +926,7 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa // grammar -struct llama_constraint_context_grammar { +struct llama_sampler_context_grammar { const struct llama_vocab * vocab; std::string grammar_str; @@ -866,22 +935,22 @@ struct llama_constraint_context_grammar { struct llama_grammar * grammar; }; -static struct llama_constraint_i llama_constraint_grammar_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "grammar"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_grammar_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (ctx->grammar) { llama_grammar_accept_impl(*ctx->grammar, token); } }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (ctx->grammar) { - llama_constraint_grammar_impl(cur_p, *ctx->grammar); + llama_sampler_grammar_impl(cur_p, *ctx->grammar); } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (!ctx->grammar) { return; } @@ -891,12 +960,12 @@ static struct llama_constraint_i llama_constraint_grammar_i = { llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx_src = (const llama_sampler_context_grammar *) smpl->ctx; - auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); + auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); - auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx; + auto * ctx_dst = (llama_sampler_context_grammar *) result->ctx; if (ctx_src->grammar) { ctx_dst->grammar_str = ctx_src->grammar_str; ctx_dst->grammar_root = ctx_src->grammar_root; @@ -906,8 +975,8 @@ static struct llama_constraint_i llama_constraint_grammar_i = { return result; }, - /* .free = */ [](struct llama_constraint * cnstr) { - const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (ctx->grammar) { llama_grammar_free_impl(ctx->grammar); @@ -917,8 +986,8 @@ static struct llama_constraint_i llama_constraint_grammar_i = { }, }; -struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { - auto * ctx = new llama_constraint_context_grammar; +struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + auto * ctx = new llama_sampler_context_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { @@ -936,15 +1005,15 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ }; } - return new llama_constraint { - /* .iface = */ &llama_constraint_grammar_i, + return new llama_sampler { + /* .iface = */ &llama_sampler_grammar_i, /* .ctx = */ ctx, }; } // penalties -struct llama_constraint_context_penalties { +struct llama_sampler_context_penalties { const struct llama_vocab * vocab; const int32_t penalty_last_n; @@ -958,16 +1027,16 @@ struct llama_constraint_context_penalties { ring_buffer prev; }; -static struct llama_constraint_i llama_constraint_penalties_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "penalties"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_penalties_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; ctx->prev.push_back(token); }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; - GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary"); + GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary"); if (ctx->ignore_eos) { cur_p->data[ctx->vocab->special_eos_id].logit = -INFINITY; @@ -981,26 +1050,26 @@ static struct llama_constraint_i llama_constraint_penalties_i = { const float nl_logit = !ctx->penalize_nl ? cur_p->data[ctx->vocab->linefeed_id].logit : -INFINITY; // Create a frequency map to count occurrences of each token in last_tokens - // TODO: optimize this by maintaining the token count in the constraint context + // TODO: optimize this by maintaining the token count in the sampler context llama_token_cnt token_count; for (int i = 0; i < ctx->penalty_last_n; ++i) { token_count[ctx->prev.rat(i)]++; } - llama_constraint_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); if (!ctx->penalize_nl) { // restore the logit of the newline token if it was penalized cur_p->data[ctx->vocab->linefeed_id].logit = nl_logit; } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; ctx->prev.clear(); }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr->ctx; - auto * result = llama_constraint_init_penalties_impl( + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx_src = (const llama_sampler_context_penalties *) smpl->ctx; + auto * result = llama_sampler_init_penalties_impl( *ctx_src->vocab, ctx_src->penalty_last_n, ctx_src->penalty_repeat, @@ -1009,23 +1078,23 @@ static struct llama_constraint_i llama_constraint_penalties_i = { ctx_src->penalize_nl, ctx_src->ignore_eos); - auto * ctx_dst = (llama_constraint_context_penalties *) result->ctx; + auto * ctx_dst = (llama_sampler_context_penalties *) result->ctx; ctx_dst->prev = ctx_src->prev; return result; }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_penalties *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_penalties *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { +struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); - return new llama_constraint { - /* .iface = */ &llama_constraint_penalties_i, - /* .ctx = */ new llama_constraint_context_penalties { + return new llama_sampler { + /* .iface = */ &llama_sampler_penalties_i, + /* .ctx = */ new llama_sampler_context_penalties { /* .vocab = */ &vocab, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, @@ -1040,100 +1109,156 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam // logit-bias -struct llama_constraint_context_logit_bias { +struct llama_sampler_context_logit_bias { const struct llama_vocab * vocab; std::vector logit_bias; }; -static struct llama_constraint_i llama_constraint_logit_bias_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "logit-bias"; }, +static struct llama_sampler_i llama_sampler_logit_bias_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_logit_bias *) smpl->ctx; - GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary"); + GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary"); for (const auto & lb : ctx->logit_bias) { cur_p->data[lb.token].logit += lb.bias; } }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr->ctx; - return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx_src = (const llama_sampler_context_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_logit_bias *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_logit_bias *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_logit_bias_impl( +struct llama_sampler * llama_sampler_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - return new llama_constraint { - /* .iface = */ &llama_constraint_logit_bias_i, - /* .ctx = */ new llama_constraint_context_logit_bias { + return new llama_sampler { + /* .iface = */ &llama_sampler_logit_bias_i, + /* .ctx = */ new llama_sampler_context_logit_bias { /* .vocab = */ &vocab, /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), }, }; } +// sampler chain + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token /*token*/) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + chain->n_sample++; + }, + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_timing); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply_impl(*smpl, cur_p); + } + }, + /* .reset = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset_impl(*smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; + }, + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init_impl(chain_src->params); + + auto * chain_dst = (llama_sampler_chain *) result->ctx; + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add_impl(*chain_dst, llama_sampler_clone_impl(*smpl)); + } + + return result; + }, + /* .free = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free_impl(smpl); + } + + delete chain; + }, +}; + +struct llama_sampler * llama_sampler_chain_init_impl(struct llama_sampler_chain_params params) { + return new llama_sampler { + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }, + }; +} + +void llama_sampler_chain_add_impl(struct llama_sampler_chain & chain, struct llama_sampler * smpl) { + chain.samplers.push_back(smpl); +} + +struct llama_sampler * llama_sampler_chain_get_impl(const struct llama_sampler_chain & chain, int32_t i) { + if (i < 0 || i >= (int32_t) chain.samplers.size()) { + return nullptr; + } + + return chain.samplers[i]; +} + +int llama_sampler_chain_n_impl(const struct llama_sampler_chain & chain) { + return chain.samplers.size(); +} + + //////////////////////////////////////// -struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr) { - return cnstr.iface->clone ? cnstr.iface->clone(&cnstr) : nullptr; -} - -void llama_constraint_free_impl(struct llama_constraint * cnstr) { - if (cnstr == nullptr) { - return; +const char * llama_sampler_name_impl(const struct llama_sampler & smpl) { + if (!smpl.iface) { + return "(null)"; } - if (cnstr->iface->free) { - cnstr->iface->free(cnstr); - } - - delete cnstr; + return smpl.iface->name(&smpl); } -void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token) { - if (cnstr.iface->accept) { - cnstr.iface->accept(&cnstr, token); +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { + if (smpl.iface->accept) { + smpl.iface->accept(&smpl, token); } } -void llama_constraint_apply_impl(struct llama_constraint & cnstr, struct llama_token_data_array * cur_p) { - GGML_ASSERT(cnstr.iface->apply); - cnstr.iface->apply(&cnstr, cur_p); +void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { + GGML_ASSERT(smpl.iface->apply); + smpl.iface->apply(&smpl, cur_p); } -void llama_constraint_reset_impl(struct llama_constraint & cnstr) { - if (cnstr.iface->reset) { - cnstr.iface->reset(&cnstr); +void llama_sampler_reset_impl(struct llama_sampler & smpl) { + if (smpl.iface->reset) { + smpl.iface->reset(&smpl); } } -// -// samplers -// - -struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { - return new llama_sampler { - /* .params = */ params, - /* .vocab = */ &vocab, - - /* .rng = */ std::mt19937(params.seed), - - /* .prev = */ { (size_t) params.n_prev }, - /* .constraints = */ {}, - /* .cur = */ {}, - /* .cur_p = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, - }; +struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { + return smpl.iface->clone ? smpl.iface->clone(&smpl) : nullptr; } void llama_sampler_free_impl(struct llama_sampler * smpl) { @@ -1141,129 +1266,9 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) { return; } - for (auto * cnstr : smpl->constraints) { - llama_constraint_free_impl(cnstr); + if (smpl->iface->free) { + smpl->iface->free(smpl); } delete smpl; } - -struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { - auto * result = new llama_sampler { - /* .params = */ smpl.params, - /* .vocab = */ smpl.vocab, - - /* .rng = */ smpl.rng, - - /* .prev = */ smpl.prev, - /* .constraints = */ {}, - /* .cur = */ {}, - /* .cur_p = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, - }; - - // clone the constraints objects - result->constraints.clear(); - for (const auto & cnstr : smpl.constraints) { - if (cnstr->ctx == nullptr) { - result->constraints.push_back(new llama_constraint { - /* .iface = */ cnstr->iface, - /* .ctx = */ nullptr, - }); - } else { - GGML_ASSERT(cnstr->iface->clone); - result->constraints.push_back(cnstr->iface->clone(cnstr)); - } - } - - return result; -} - -void llama_sampler_reset_impl(struct llama_sampler & smpl) { - smpl.prev.clear(); - - for (auto * cnstr : smpl.constraints) { - llama_constraint_reset_impl(*cnstr); - } - - // TODO: should we reset the timings? -} - -const char * llama_constraint_name_impl(const struct llama_constraint & cnstr) { - if (!cnstr.iface) { - return "(null)"; - } - - return cnstr.iface->name(&cnstr); -} - -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { - smpl.prev.push_back(token); - - for (auto * cnstr : smpl.constraints) { - llama_constraint_accept_impl(*cnstr, token); - } -} - -void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { - for (auto * cnstr : smpl.constraints) { - llama_constraint_apply_impl(*cnstr, cur_p); - } -} - -void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { - smpl.constraints.push_back(cnstr); -} - -int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) { - return smpl.constraints.size(); -} - -struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) { - if (ith < 0 || ith >= (int) smpl.constraints.size()) { - return nullptr; - } - - return smpl.constraints[ith]; -} - -llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type) { - switch (type) { - case LLAMA_SAMPLER_TYPE_GREEDY: - { - llama_constraint_softmax_impl(cur_p); - - return cur_p->data[0].id; - } - case LLAMA_SAMPLER_TYPE_DIST: - { - llama_constraint_softmax_impl(cur_p); - - std::vector probs(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - probs[i] = cur_p->data[i].p; - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - const int idx = dist(rng); - - return cur_p->data[idx].id; - } - default: - GGML_ABORT("invalid sampler type"); - } -} - -llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) { - if (ith < 0 || ith >= (int) smpl.prev.size()) { - return LLAMA_TOKEN_NULL; - } - - return smpl.prev.rat(ith); -} - -int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { - return smpl.prev.size(); -} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 18304b49a..3f14ec621 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -2,89 +2,26 @@ #include "llama-grammar.h" -#include #include struct llama_vocab; struct llama_grammar; -using llama_token_cnt = std::unordered_map; - -// TODO: tmp exposed until test-sampling is fixed -void llama_constraint_penalties_impl( - llama_token_data_array * cur_p, - const llama_token_cnt & token_count, - float penalty_repeat, - float penalty_freq, - float penalty_present); - -// constraints - -struct llama_constraint * llama_constraint_init_softmax_impl (); -struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k); -struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); -struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_temp_impl (float t); -struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); - -struct llama_constraint * llama_constraint_init_mirostat_impl( - const struct llama_vocab & vocab, - float tau, - float eta, - int32_t m); - -struct llama_constraint * llama_constraint_init_mirostat_v2_impl( - float tau, - float eta); - -struct llama_constraint * llama_constraint_init_grammar_impl( - const struct llama_vocab & vocab, - const char * grammar_str, - const char * grammar_root); - -struct llama_constraint * llama_constraint_init_penalties_impl( - const struct llama_vocab & vocab, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos); - - LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias_impl( - const struct llama_vocab & vocab, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias); - -struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr); - -void llama_constraint_free_impl(struct llama_constraint * cnstr); - -const char * llama_constraint_name_impl (const struct llama_constraint & cnstr); -void llama_constraint_accept_impl( struct llama_constraint & cnstr, llama_token token); -void llama_constraint_apply_impl ( struct llama_constraint & cnstr, struct llama_token_data_array * cur_p); -void llama_constraint_reset_impl ( struct llama_constraint & cnstr); - // samplers -struct llama_sampler { - llama_sampler_params params; +const char * llama_sampler_name_impl (const struct llama_sampler & smpl); +void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); +void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); +void llama_sampler_reset_impl ( struct llama_sampler & smpl); +struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); +void llama_sampler_free_impl ( struct llama_sampler * smpl); - const struct llama_vocab * vocab; +// sampler chain - // state +struct llama_sampler_chain { + llama_sampler_chain_params params; - std::mt19937 rng; - - ring_buffer prev; - - std::vector constraints; - - std::vector cur; - - llama_token_data_array cur_p; + std::vector samplers; // timing @@ -93,18 +30,57 @@ struct llama_sampler { mutable int32_t n_sample; }; -struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); -void llama_sampler_free_impl ( struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); -void llama_sampler_reset_impl ( struct llama_sampler & smpl); -void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); +struct llama_sampler * llama_sampler_chain_init_impl( struct llama_sampler_chain_params params); +void llama_sampler_chain_add_impl ( struct llama_sampler_chain & chain, struct llama_sampler * smpl); +struct llama_sampler * llama_sampler_chain_get_impl (const struct llama_sampler_chain & chain, int32_t i); +int llama_sampler_chain_n_impl (const struct llama_sampler_chain & chain); -void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); -int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); -struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); +using llama_token_cnt = std::unordered_map; -llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type); +// TODO: tmp exposed until test-sampling is fixed +void llama_sampler_penalties_impl( + llama_token_data_array * cur_p, + const llama_token_cnt & token_count, + float penalty_repeat, + float penalty_freq, + float penalty_present); -llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); -int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); +struct llama_sampler * llama_sampler_init_greedy_impl (); +struct llama_sampler * llama_sampler_init_dist_impl (uint32_t seed); +struct llama_sampler * llama_sampler_init_softmax_impl (); +struct llama_sampler * llama_sampler_init_top_k_impl (int32_t k); +struct llama_sampler * llama_sampler_init_top_p_impl (float p, size_t min_keep); +struct llama_sampler * llama_sampler_init_min_p_impl (float p, size_t min_keep); +struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep); +struct llama_sampler * llama_sampler_init_typical_impl (float p, size_t min_keep); +struct llama_sampler * llama_sampler_init_temp_impl (float t); +struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta, float exponent); + +struct llama_sampler * llama_sampler_init_mirostat_impl( + const struct llama_vocab & vocab, + float tau, + float eta, + int32_t m); + +struct llama_sampler * llama_sampler_init_mirostat_v2_impl( + float tau, + float eta); + +struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab & vocab, + const char * grammar_str, + const char * grammar_root); + +struct llama_sampler * llama_sampler_init_penalties_impl( + const struct llama_vocab & vocab, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos); + + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( + const struct llama_vocab & vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); diff --git a/src/llama.cpp b/src/llama.cpp index 2636f2316..df12de7ad 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -147,21 +147,6 @@ static void zeros(std::ofstream & file, size_t n) { } } -struct time_meas { - time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} - - ~time_meas() { - if (t_start_us >= 0) { - t_acc += ggml_time_us() - t_start_us; - } - } - - const int64_t t_start_us; - - int64_t & t_acc; -}; - - LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) { va_list ap; @@ -17937,11 +17922,8 @@ struct llama_context_params llama_context_default_params() { return result; } -struct llama_sampler_params llama_sampler_default_params() { - struct llama_sampler_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_prev =*/ 256, - /*.type =*/ LLAMA_SAMPLER_TYPE_DIST, +struct llama_sampler_chain_params llama_sampler_chain_default_params() { + struct llama_sampler_chain_params result = { /*.no_timing =*/ false, // TODO: change to true and set explicitly in examples }; @@ -20610,98 +20592,24 @@ int32_t llama_chat_apply_template( // sampling // -struct llama_constraint * llama_constraint_init_softmax(void) { - return llama_constraint_init_softmax_impl(); +const char * llama_sampler_name(const struct llama_sampler * smpl) { + return llama_sampler_name_impl(*smpl); } -struct llama_constraint * llama_constraint_init_top_k(int32_t k) { - return llama_constraint_init_top_k_impl(k); +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + llama_sampler_accept_impl(*smpl, token); } -struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { - return llama_constraint_init_top_p_impl(p, min_keep); +void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + llama_sampler_apply_impl(*smpl, cur_p); } -struct llama_constraint * llama_constraint_init_min_p(float p, int32_t min_keep) { - return llama_constraint_init_min_p_impl(p, min_keep); +void llama_sampler_reset(struct llama_sampler * smpl) { + llama_sampler_reset_impl(*smpl); } -struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep) { - return llama_constraint_init_tail_free_impl(z, min_keep); -} - -struct llama_constraint * llama_constraint_init_typical(float p, int32_t min_keep) { - return llama_constraint_init_typical_impl(p, min_keep); -} - -struct llama_constraint * llama_constraint_init_temp(float temp) { - return llama_constraint_init_temp_impl(temp); -} - -struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta, float exponent) { - return llama_constraint_init_temp_ext_impl(temp, delta, exponent); -} - -struct llama_constraint * llama_constraint_init_mirostat(const struct llama_model * model, float tau, float eta) { - return llama_constraint_init_mirostat_impl(model->vocab, tau, eta, 100); -} - -struct llama_constraint * llama_constraint_init_mirostat_v2(float tau, float eta) { - return llama_constraint_init_mirostat_v2_impl(tau, eta); -} - -struct llama_constraint * llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { - return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); -} - -struct llama_constraint * llama_constraint_init_penalties( - const struct llama_model * model, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { - return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); -} - -LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( - const struct llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias) { - return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); -} - -struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr) { - return llama_constraint_clone_impl(*cnstr); -} - -void llama_constraint_free(struct llama_constraint * cnstr) { - if (cnstr == nullptr) { - return; - } - - llama_constraint_free_impl(cnstr); -} - -const char * llama_constraint_name(const struct llama_constraint * cnstr) { - return llama_constraint_name_impl(*cnstr); -} - -void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) { - llama_constraint_accept_impl(*cnstr, token); -} - -void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - llama_constraint_apply_impl(*cnstr, cur_p); -} - -void llama_constraint_reset(struct llama_constraint * cnstr) { - llama_constraint_reset_impl(*cnstr); -} - -struct llama_sampler * llama_sampler_init(const struct llama_model * model, struct llama_sampler_params params) { - return llama_sampler_init_impl(model->vocab, params); +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + return llama_sampler_clone_impl(*smpl); } void llama_sampler_free(struct llama_sampler * smpl) { @@ -20712,86 +20620,110 @@ void llama_sampler_free(struct llama_sampler * smpl) { llama_sampler_free_impl(smpl); } -struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { - return llama_sampler_clone_impl(*smpl); +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return llama_sampler_chain_init_impl(params); } -void llama_sampler_reset(struct llama_sampler * smpl) { - llama_sampler_reset_impl(*smpl); +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + llama_sampler_chain_add_impl(*(struct llama_sampler_chain *) chain->ctx, smpl); } -void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { - llama_sampler_accept_impl(*smpl, token); +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + return llama_sampler_chain_get_impl(*(const struct llama_sampler_chain *) chain->ctx, i); } -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us, smpl->params.no_timing); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - llama_sampler_apply_impl(*smpl, cur_p); +int llama_sampler_chain_n(const struct llama_sampler * chain) { + return llama_sampler_chain_n_impl(*(const struct llama_sampler_chain *) chain->ctx); } -void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { - const int n_vocab = smpl->vocab->n_vocab; +struct llama_sampler * llama_sampler_init_greedy(void) { + return llama_sampler_init_greedy_impl(); +} - smpl->cur.resize(n_vocab); +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + return llama_sampler_init_dist_impl(seed); +} +struct llama_sampler * llama_sampler_init_softmax(void) { + return llama_sampler_init_softmax_impl(); +} + +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + return llama_sampler_init_top_k_impl(k); +} + +struct llama_sampler * llama_sampler_init_top_p(float p, int32_t min_keep) { + return llama_sampler_init_top_p_impl(p, min_keep); +} + +struct llama_sampler * llama_sampler_init_min_p(float p, int32_t min_keep) { + return llama_sampler_init_min_p_impl(p, min_keep); +} + +struct llama_sampler * llama_sampler_init_tail_free(float z, int32_t min_keep) { + return llama_sampler_init_tail_free_impl(z, min_keep); +} + +struct llama_sampler * llama_sampler_init_typical(float p, int32_t min_keep) { + return llama_sampler_init_typical_impl(p, min_keep); +} + +struct llama_sampler * llama_sampler_init_temp(float temp) { + return llama_sampler_init_temp_impl(temp); +} + +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { + return llama_sampler_init_temp_ext_impl(temp, delta, exponent); +} + +struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, float tau, float eta) { + return llama_sampler_init_mirostat_impl(model->vocab, tau, eta, 100); +} + +struct llama_sampler * llama_sampler_init_mirostat_v2(float tau, float eta) { + return llama_sampler_init_mirostat_v2_impl(tau, eta); +} + +struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { + return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); +} + +struct llama_sampler * llama_sampler_init_penalties( + const struct llama_model * model, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { + return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); +} + +LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( + const struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + // TODO: do not allocate each time + std::vector cur(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + + llama_sampler_apply(smpl, &cur_p); + + return cur_p.data[cur_p.selected].id; } -llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl) { - return &smpl->cur_p; -} - -void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) { - llama_sampler_constraint_add_impl(*smpl, cnstr); -} - -int llama_sampler_n_constraints (const struct llama_sampler * smpl) { - return llama_sampler_n_constraints_impl(*smpl); -} - -struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) { - return llama_sampler_constraint_get_impl(*smpl, i); -} - -llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us, smpl->params.no_timing); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - auto res = llama_sampler_sample_impl(cur_p, smpl->rng, smpl->params.type); - - smpl->n_sample++; - - return res; -} - -int llama_sampler_n_prev(const struct llama_sampler * smpl) { - return llama_sampler_n_prev_impl(*smpl); -} - -llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith) { - return llama_sampler_prev_impl(*smpl, ith); -} - -llama_token llama_sampler_last(const struct llama_sampler * smpl) { - return llama_sampler_prev_impl(*smpl, 0); -} - -//llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { -// GGML_ABORT("not implemented"); -//} - // // model split // @@ -20820,7 +20752,9 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl) { +void llama_print_timings(const struct llama_context * ctx, const struct llama_sampler * chain) { + auto * smpl = chain ? (const struct llama_sampler_chain *) chain->ctx : nullptr; + const llama_timings timings = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), @@ -20845,13 +20779,15 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * chain) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - if (smpl) { - smpl->t_sample_us = smpl->n_sample = 0; + if (chain) { + auto * smpl = (struct llama_sampler_chain *) chain->ctx; + + smpl->t_sample_us = smpl->n_sample = 0; } } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 74bb4a3a3..adc1ff4e6 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -21,8 +21,8 @@ static void dump(const llama_token_data_array * cur_p) { #define APPLY(__cnstr, __cur_p) do { \ auto * cnstr = (__cnstr); \ - llama_constraint_apply(cnstr, (__cur_p)); \ - llama_constraint_free(cnstr); \ + llama_sampler_apply(cnstr, (__cur_p)); \ + llama_sampler_free(cnstr); \ } while(0) static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { @@ -35,10 +35,10 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false }; DUMP(&cur_p); - APPLY(llama_constraint_init_tail_free(z, 1), &cur_p); + APPLY(llama_sampler_init_tail_free(z, 1), &cur_p); DUMP(&cur_p); GGML_ASSERT(cur_p.size == expected_probs.size()); @@ -100,11 +100,11 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector