sampling : simplify sample API

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-04 21:23:35 +03:00
parent e7a11cac0e
commit 8e80a1cf6b
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
12 changed files with 147 additions and 199 deletions

View file

@ -40,8 +40,9 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
llama_sampler_params lparams = llama_sampler_default_params(); llama_sampler_params lparams = llama_sampler_default_params();
lparams.seed = params.seed; lparams.seed = params.seed;
lparams.n_prev = params.n_prev; lparams.n_prev = params.n_prev;
lparams.type = params.temp <= 0.0f ? LLAMA_SAMPLER_TYPE_GREEDY : LLAMA_SAMPLER_TYPE_DIST;
auto * result = new gpt_sampler { auto * result = new gpt_sampler {
/* .params = */ params, /* .params = */ params,
@ -61,39 +62,41 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
/* .smpl = */ llama_sampler_init(model, lparams) /* .smpl = */ llama_sampler_init(model, lparams)
}; };
if (params.mirostat == 0) { if (params.temp > 0.0f) {
for (const auto & cnstr : params.constraints) { if (params.mirostat == 0) {
switch (cnstr) { for (const auto & cnstr : params.constraints) {
case GPT_CONSTRAINT_TYPE_TOP_K: switch (cnstr) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); case GPT_CONSTRAINT_TYPE_TOP_K:
break; llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
case GPT_CONSTRAINT_TYPE_TOP_P: break;
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); case GPT_CONSTRAINT_TYPE_TOP_P:
break; llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
case GPT_CONSTRAINT_TYPE_MIN_P: break;
llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); case GPT_CONSTRAINT_TYPE_MIN_P:
break; llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
case GPT_CONSTRAINT_TYPE_TFS_Z: break;
llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); case GPT_CONSTRAINT_TYPE_TFS_Z:
break; llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
case GPT_CONSTRAINT_TYPE_TYPICAL_P: break;
llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); case GPT_CONSTRAINT_TYPE_TYPICAL_P:
break; llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
case GPT_CONSTRAINT_TYPE_TEMPERATURE: break;
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); case GPT_CONSTRAINT_TYPE_TEMPERATURE:
break; llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
default: break;
GGML_ASSERT(false && "unknown constraint type"); default:
GGML_ASSERT(false && "unknown constraint type");
}
} }
} else if (params.mirostat == 1) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
} else if (params.mirostat == 2) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
} }
} else if (params.mirostat == 1) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
} else if (params.mirostat == 2) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
} }
return result; return result;
@ -151,45 +154,11 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) {
llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr);
} }
static llama_token gpt_sampler_sample( llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p) {
struct llama_sampler * smpl, return llama_sampler_sample(gsmpl->smpl, cur_p);
struct llama_token_data_array * cur_p,
float temp,
int n_probs) {
llama_token res = 0;
if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) {
// greedy sampling, with probs
res = llama_sampler_sample_greedy(smpl, cur_p, true);
} else if (temp == 0.0f) {
// greedy sampling, no probs
res = llama_sampler_sample_greedy(smpl, cur_p, false);
} else {
// apply all sampling constraints and then sample
llama_sampler_apply(smpl, cur_p);
res = llama_sampler_sample_dist(smpl, cur_p);
//{
// const int n_top = 10;
// LOG("top %d candidates:\n", n_top);
// for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
// }
//}
//LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
}
return res;
} }
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
const auto & params = gsmpl->params;
auto & bias = gsmpl->bias; auto & bias = gsmpl->bias;
auto & pnlt = gsmpl->pnlt; auto & pnlt = gsmpl->pnlt;
auto & grmr = gsmpl->grmr; auto & grmr = gsmpl->grmr;
@ -204,8 +173,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
llama_constraint_apply(bias, cur_p); llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(pnlt, cur_p);
// first, sample the token without any grammar constraints llama_sampler_apply(smpl, cur_p);
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs);
const llama_token id = llama_sampler_sample(smpl, cur_p);
// check if it the sampled token fits the grammar // check if it the sampled token fits the grammar
{ {
@ -228,7 +198,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(pnlt, cur_p);
llama_constraint_apply(grmr, cur_p); llama_constraint_apply(grmr, cur_p);
return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs); llama_sampler_apply(smpl, cur_p);
return llama_sampler_sample(smpl, cur_p);
} }
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
@ -237,14 +209,6 @@ void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_arra
llama_constraint_apply(gsmpl->grmr, cur_p); llama_constraint_apply(gsmpl->grmr, cur_p);
} }
llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
return llama_sampler_sample_dist(gsmpl->smpl, cur_p);
}
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs) {
return llama_sampler_sample_greedy(gsmpl->smpl, cur_p, probs);
}
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
auto & smpl = gsmpl->smpl; auto & smpl = gsmpl->smpl;

View file

@ -73,15 +73,19 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl);
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl); void gpt_sampler_reset (struct gpt_sampler * gsmpl);
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p);
void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits);
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p);
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
// common sampling implementation: // extended sampling implementation:
// //
// - set logits // - set logits
// - apply the configured sampling constraints // - apply the configured sampling constraints
@ -90,11 +94,6 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
// //
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx);
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p);
llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * cur_p);
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs);
// helpers // helpers
// print the constraints into a string // print the constraints into a string

View file

@ -66,7 +66,7 @@ int main(int argc, char ** argv) {
auto sparams = llama_sampler_default_params(); auto sparams = llama_sampler_default_params();
sparams.seed = params.sparams.seed; sparams.seed = params.sparams.seed;
llama_sampler * smpl = llama_sampler_init(model, sparams); llama_sampler * smpl = llama_sampler_init(model, sparams);
@ -177,9 +177,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits); llama_sampler_set_logits(smpl, logits);
const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr); const llama_token new_token_id = llama_sampler_sample(smpl, nullptr);
//const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

View file

@ -124,7 +124,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_sampler_set_logits(smpl, logits); llama_sampler_set_logits(smpl, logits);
llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false); llama_token token = llama_sampler_sample(smpl, nullptr);
if (token == eos_token) { if (token == eos_token) {
break; break;
} }
@ -171,7 +171,11 @@ int main(int argc, char * argv[]) {
// create generation context // create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams); llama_context * ctx = llama_new_context_with_model(model, cparams);
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); auto sparams = llama_sampler_default_params();
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY;
llama_sampler * smpl = llama_sampler_init(model, sparams);
// ### Embedding/Representation ### // ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic // samples taken from: https://github.com/ContextualAI/gritlm#basic

View file

@ -83,7 +83,11 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); auto sparams = llama_sampler_default_params();
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY;
llama_sampler * smpl = llama_sampler_init(model, sparams);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> tokens_list; std::vector<llama_token> tokens_list;
@ -221,7 +225,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits); llama_sampler_set_logits(smpl, logits);
// sample the most likely token // sample the most likely token
const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); const llama_token new_token_id = llama_sampler_sample(smpl, nullptr);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

View file

@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits); llama_sampler_set_logits(smpl, logits);
auto next_token = llama_sampler_sample_dist(smpl, nullptr); auto next_token = llama_sampler_sample(smpl, nullptr);
auto next_token_str = llama_token_to_piece(ctx, next_token); auto next_token_str = llama_token_to_piece(ctx, next_token);
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
@ -130,7 +130,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl2, logits); llama_sampler_set_logits(smpl2, logits);
auto next_token = llama_sampler_sample_dist(smpl2, nullptr); auto next_token = llama_sampler_sample(smpl2, nullptr);
auto next_token_str = llama_token_to_piece(ctx2, next_token); auto next_token_str = llama_token_to_piece(ctx2, next_token);
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl3, logits); llama_sampler_set_logits(smpl3, logits);
auto next_token = llama_sampler_sample_dist(smpl3, nullptr); auto next_token = llama_sampler_sample(smpl3, nullptr);
auto next_token_str = llama_token_to_piece(ctx3, next_token); auto next_token_str = llama_token_to_piece(ctx3, next_token);
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());

View file

@ -55,7 +55,11 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); auto sparams = llama_sampler_default_params();
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY;
llama_sampler * smpl = llama_sampler_init(model, sparams);
// tokenize the prompt // tokenize the prompt
@ -117,7 +121,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits); llama_sampler_set_logits(smpl, logits);
// sample the most likely token // sample the most likely token
const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); const llama_token new_token_id = llama_sampler_sample(smpl, nullptr);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

View file

@ -182,6 +182,8 @@ int main(int argc, char ** argv) {
// target model sampling context (reuse the llama_context's sampling instance) // target model sampling context (reuse the llama_context's sampling instance)
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
struct llama_constraint * softmax = llama_constraint_init_softmax();
// draft sequence data // draft sequence data
std::vector<seq_draft> drafts(n_seq_dft); std::vector<seq_draft> drafts(n_seq_dft);
@ -236,7 +238,7 @@ int main(int argc, char ** argv) {
auto & dist_tgt = *gpt_sampler_get_candidates(smpl); auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
gpt_sampler_apply_grammar(smpl, &dist_tgt); gpt_sampler_apply_grammar(smpl, &dist_tgt);
gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax llama_constraint_apply(softmax, &dist_tgt);
float p_tgt = 0.0f; float p_tgt = 0.0f;
float p_dft = 0.0f; float p_dft = 0.0f;
@ -335,11 +337,10 @@ int main(int argc, char ** argv) {
// all drafted tokens were rejected // all drafted tokens were rejected
// sample from the target model // sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n"); LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = gpt_sampler_sample_dist(smpl, &dist_tgt); token_id = gpt_sampler_sample(smpl, &dist_tgt);
gpt_sampler_accept(smpl, token_id, true); gpt_sampler_accept(smpl, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id); token_str = llama_token_to_piece(ctx_tgt, token_id);
} }
} else { } else {
// greedy verification // greedy verification
@ -615,6 +616,7 @@ int main(int argc, char ** argv) {
gpt_sampler_free(drafts[s].smpl); gpt_sampler_free(drafts[s].smpl);
} }
llama_constraint_free(softmax);
llama_batch_free(batch_dft); llama_batch_free(batch_dft);
llama_free(ctx_tgt); llama_free(ctx_tgt);

View file

@ -370,8 +370,8 @@ extern "C" {
} llama_logit_bias; } llama_logit_bias;
enum llama_sampler_type { enum llama_sampler_type {
LLAMA_SAMPLER_TYPE_GREEDY = 0, LLAMA_SAMPLER_TYPE_GREEDY = 0,
LLAMA_SAMPLER_TYPE_DIST = 1, LLAMA_SAMPLER_TYPE_DIST = 1,
}; };
typedef struct llama_sampler_params { typedef struct llama_sampler_params {
@ -1092,10 +1092,12 @@ extern "C" {
// samplers // samplers
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params);
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits);
@ -1107,11 +1109,7 @@ extern "C" {
LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i);
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
/// @details Get the number of accepted tokens so far (max of n_prev) /// @details Get the number of accepted tokens so far (max of n_prev)
LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl);

View file

@ -1183,6 +1183,20 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
// TODO: should we reset the timings? // TODO: should we reset the timings?
} }
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
smpl.prev.push_back(token);
for (auto * cnstr : smpl.constraints) {
llama_constraint_accept_impl(*cnstr, token);
}
}
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) {
for (auto * cnstr : smpl.constraints) {
llama_constraint_apply_impl(*cnstr, cur_p);
}
}
void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
smpl.constraints.push_back(cnstr); smpl.constraints.push_back(cnstr);
} }
@ -1199,17 +1213,31 @@ struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_s
return smpl.constraints[ith]; return smpl.constraints[ith];
} }
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type) {
smpl.prev.push_back(token); switch (type) {
case LLAMA_SAMPLER_TYPE_GREEDY:
{
llama_constraint_softmax_impl(cur_p);
for (auto * cnstr : smpl.constraints) { return cur_p->data[0].id;
llama_constraint_accept_impl(*cnstr, token); }
} case LLAMA_SAMPLER_TYPE_DIST:
} {
llama_constraint_softmax_impl(cur_p);
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { std::vector<float> probs(cur_p->size);
for (auto * cnstr : smpl.constraints) { for (size_t i = 0; i < cur_p->size; ++i) {
llama_constraint_apply_impl(*cnstr, cur_p); probs[i] = cur_p->data[i].p;
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
return cur_p->data[idx].id;
}
default:
GGML_ABORT("invalid sampler type");
} }
} }
@ -1224,40 +1252,3 @@ llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith)
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) {
return smpl.prev.size(); return smpl.prev.size();
} }
llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) {
if (probs) {
// if probs are needed, we apply softmax to get the probabilities
llama_constraint_softmax_impl(cur_p);
// the cur_p are sorted, so we can just return the first one
return cur_p->data[0].id;
}
// return the token with the highest logit
auto * max_iter = std::max_element(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit < b.logit;
});
llama_token result = max_iter->id;
return result;
}
llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng) {
llama_constraint_softmax_impl(cur_p);
std::vector<float> probs;
probs.reserve(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
probs.push_back(cur_p->data[i].p);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
llama_token result = cur_p->data[idx].id;
return result;
}

View file

@ -104,20 +104,18 @@ struct llama_sampler {
mutable int32_t n_sample; mutable int32_t n_sample;
}; };
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);
void llama_sampler_free_impl ( struct llama_sampler * smpl); void llama_sampler_free_impl ( struct llama_sampler * smpl);
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
void llama_sampler_reset_impl( struct llama_sampler & smpl); void llama_sampler_reset_impl ( struct llama_sampler & smpl);
void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token);
void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr);
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl);
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith);
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type);
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith);
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl);
llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs);
llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng);

View file

@ -17939,7 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() {
struct llama_sampler_params result = { struct llama_sampler_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED, /*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_prev =*/ 256, /*.n_prev =*/ 256,
/*.type =*/ LLAMA_SAMPLER_TYPE_GREEDY, /*.type =*/ LLAMA_SAMPLER_TYPE_DIST,
}; };
return result; return result;
@ -20713,6 +20713,20 @@ void llama_sampler_reset(struct llama_sampler * smpl) {
llama_sampler_reset_impl(*smpl); llama_sampler_reset_impl(*smpl);
} }
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
llama_sampler_accept_impl(*smpl, token);
}
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
llama_sampler_apply_impl(*smpl, cur_p);
}
void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) {
const int n_vocab = smpl->vocab->n_vocab; const int n_vocab = smpl->vocab->n_vocab;
@ -20741,42 +20755,14 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample
return llama_sampler_constraint_get_impl(*smpl, i); return llama_sampler_constraint_get_impl(*smpl, i);
} }
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
llama_sampler_accept_impl(*smpl, token);
}
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us); time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) { if (cur_p == nullptr) {
cur_p = &smpl->cur_p; cur_p = &smpl->cur_p;
} }
llama_sampler_apply_impl(*smpl, cur_p); auto res = llama_sampler_sample_impl(cur_p, smpl->rng, smpl->params.type);
}
llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
auto res = llama_sampler_sample_greedy_impl(cur_p, probs);
smpl->n_sample++;
return res;
}
llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
auto res = llama_sampler_sample_dist_impl(cur_p, smpl->rng);
smpl->n_sample++; smpl->n_sample++;