sampling : simplify sample API

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-04 21:23:35 +03:00
parent e7a11cac0e
commit 8e80a1cf6b
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
12 changed files with 147 additions and 199 deletions

View file

@ -40,8 +40,9 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
llama_sampler_params lparams = llama_sampler_default_params();
lparams.seed = params.seed;
lparams.n_prev = params.n_prev;
lparams.seed = params.seed;
lparams.n_prev = params.n_prev;
lparams.type = params.temp <= 0.0f ? LLAMA_SAMPLER_TYPE_GREEDY : LLAMA_SAMPLER_TYPE_DIST;
auto * result = new gpt_sampler {
/* .params = */ params,
@ -61,39 +62,41 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
/* .smpl = */ llama_sampler_init(model, lparams)
};
if (params.mirostat == 0) {
for (const auto & cnstr : params.constraints) {
switch (cnstr) {
case GPT_CONSTRAINT_TYPE_TOP_K:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TOP_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_MIN_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TFS_Z:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
default:
GGML_ASSERT(false && "unknown constraint type");
if (params.temp > 0.0f) {
if (params.mirostat == 0) {
for (const auto & cnstr : params.constraints) {
switch (cnstr) {
case GPT_CONSTRAINT_TYPE_TOP_K:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TOP_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_MIN_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TFS_Z:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
default:
GGML_ASSERT(false && "unknown constraint type");
}
}
} else if (params.mirostat == 1) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
} else if (params.mirostat == 2) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
} else if (params.mirostat == 1) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
} else if (params.mirostat == 2) {
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
return result;
@ -151,45 +154,11 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) {
llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr);
}
static llama_token gpt_sampler_sample(
struct llama_sampler * smpl,
struct llama_token_data_array * cur_p,
float temp,
int n_probs) {
llama_token res = 0;
if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) {
// greedy sampling, with probs
res = llama_sampler_sample_greedy(smpl, cur_p, true);
} else if (temp == 0.0f) {
// greedy sampling, no probs
res = llama_sampler_sample_greedy(smpl, cur_p, false);
} else {
// apply all sampling constraints and then sample
llama_sampler_apply(smpl, cur_p);
res = llama_sampler_sample_dist(smpl, cur_p);
//{
// const int n_top = 10;
// LOG("top %d candidates:\n", n_top);
// for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
// }
//}
//LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
}
return res;
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p) {
return llama_sampler_sample(gsmpl->smpl, cur_p);
}
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
const auto & params = gsmpl->params;
auto & bias = gsmpl->bias;
auto & pnlt = gsmpl->pnlt;
auto & grmr = gsmpl->grmr;
@ -204,8 +173,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p);
// first, sample the token without any grammar constraints
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs);
llama_sampler_apply(smpl, cur_p);
const llama_token id = llama_sampler_sample(smpl, cur_p);
// check if it the sampled token fits the grammar
{
@ -228,7 +198,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
llama_constraint_apply(pnlt, cur_p);
llama_constraint_apply(grmr, cur_p);
return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs);
llama_sampler_apply(smpl, cur_p);
return llama_sampler_sample(smpl, cur_p);
}
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
@ -237,14 +209,6 @@ void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_arra
llama_constraint_apply(gsmpl->grmr, cur_p);
}
llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) {
return llama_sampler_sample_dist(gsmpl->smpl, cur_p);
}
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs) {
return llama_sampler_sample_greedy(gsmpl->smpl, cur_p, probs);
}
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
auto & smpl = gsmpl->smpl;

View file

@ -73,15 +73,19 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl);
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p);
void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits);
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p);
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
// common sampling implementation:
// extended sampling implementation:
//
// - set logits
// - apply the configured sampling constraints
@ -90,11 +94,6 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
//
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx);
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p);
llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * cur_p);
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs);
// helpers
// print the constraints into a string

View file

@ -66,7 +66,7 @@ int main(int argc, char ** argv) {
auto sparams = llama_sampler_default_params();
sparams.seed = params.sparams.seed;
sparams.seed = params.sparams.seed;
llama_sampler * smpl = llama_sampler_init(model, sparams);
@ -177,9 +177,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits);
const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);
//const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
const llama_token new_token_id = llama_sampler_sample(smpl, nullptr);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

View file

@ -124,7 +124,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_sampler_set_logits(smpl, logits);
llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false);
llama_token token = llama_sampler_sample(smpl, nullptr);
if (token == eos_token) {
break;
}
@ -171,7 +171,11 @@ int main(int argc, char * argv[]) {
// create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams);
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
auto sparams = llama_sampler_default_params();
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY;
llama_sampler * smpl = llama_sampler_init(model, sparams);
// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic

View file

@ -83,7 +83,11 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
auto sparams = llama_sampler_default_params();
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY;
llama_sampler * smpl = llama_sampler_init(model, sparams);
// tokenize the prompt
std::vector<llama_token> tokens_list;
@ -221,7 +225,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits);
// sample the most likely token
const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
const llama_token new_token_id = llama_sampler_sample(smpl, nullptr);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

View file

@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits);
auto next_token = llama_sampler_sample_dist(smpl, nullptr);
auto next_token = llama_sampler_sample(smpl, nullptr);
auto next_token_str = llama_token_to_piece(ctx, next_token);
printf("%s", next_token_str.c_str());
@ -130,7 +130,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl2, logits);
auto next_token = llama_sampler_sample_dist(smpl2, nullptr);
auto next_token = llama_sampler_sample(smpl2, nullptr);
auto next_token_str = llama_token_to_piece(ctx2, next_token);
printf("%s", next_token_str.c_str());
@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl3, logits);
auto next_token = llama_sampler_sample_dist(smpl3, nullptr);
auto next_token = llama_sampler_sample(smpl3, nullptr);
auto next_token_str = llama_token_to_piece(ctx3, next_token);
printf("%s", next_token_str.c_str());

View file

@ -55,7 +55,11 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
auto sparams = llama_sampler_default_params();
sparams.type = LLAMA_SAMPLER_TYPE_GREEDY;
llama_sampler * smpl = llama_sampler_init(model, sparams);
// tokenize the prompt
@ -117,7 +121,7 @@ int main(int argc, char ** argv) {
llama_sampler_set_logits(smpl, logits);
// sample the most likely token
const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
const llama_token new_token_id = llama_sampler_sample(smpl, nullptr);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

View file

@ -182,6 +182,8 @@ int main(int argc, char ** argv) {
// target model sampling context (reuse the llama_context's sampling instance)
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
struct llama_constraint * softmax = llama_constraint_init_softmax();
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
@ -236,7 +238,7 @@ int main(int argc, char ** argv) {
auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
gpt_sampler_apply_grammar(smpl, &dist_tgt);
gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax
llama_constraint_apply(softmax, &dist_tgt);
float p_tgt = 0.0f;
float p_dft = 0.0f;
@ -335,11 +337,10 @@ int main(int argc, char ** argv) {
// all drafted tokens were rejected
// sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = gpt_sampler_sample_dist(smpl, &dist_tgt);
token_id = gpt_sampler_sample(smpl, &dist_tgt);
gpt_sampler_accept(smpl, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
} else {
// greedy verification
@ -615,6 +616,7 @@ int main(int argc, char ** argv) {
gpt_sampler_free(drafts[s].smpl);
}
llama_constraint_free(softmax);
llama_batch_free(batch_dft);
llama_free(ctx_tgt);

View file

@ -370,8 +370,8 @@ extern "C" {
} llama_logit_bias;
enum llama_sampler_type {
LLAMA_SAMPLER_TYPE_GREEDY = 0,
LLAMA_SAMPLER_TYPE_DIST = 1,
LLAMA_SAMPLER_TYPE_GREEDY = 0,
LLAMA_SAMPLER_TYPE_DIST = 1,
};
typedef struct llama_sampler_params {
@ -1092,10 +1092,12 @@ extern "C" {
// samplers
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params);
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params);
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits);
@ -1107,11 +1109,7 @@ extern "C" {
LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i);
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs);
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p);
/// @details Get the number of accepted tokens so far (max of n_prev)
LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl);

View file

@ -1183,6 +1183,20 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
// TODO: should we reset the timings?
}
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
smpl.prev.push_back(token);
for (auto * cnstr : smpl.constraints) {
llama_constraint_accept_impl(*cnstr, token);
}
}
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) {
for (auto * cnstr : smpl.constraints) {
llama_constraint_apply_impl(*cnstr, cur_p);
}
}
void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
smpl.constraints.push_back(cnstr);
}
@ -1199,17 +1213,31 @@ struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_s
return smpl.constraints[ith];
}
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
smpl.prev.push_back(token);
llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type) {
switch (type) {
case LLAMA_SAMPLER_TYPE_GREEDY:
{
llama_constraint_softmax_impl(cur_p);
for (auto * cnstr : smpl.constraints) {
llama_constraint_accept_impl(*cnstr, token);
}
}
return cur_p->data[0].id;
}
case LLAMA_SAMPLER_TYPE_DIST:
{
llama_constraint_softmax_impl(cur_p);
void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) {
for (auto * cnstr : smpl.constraints) {
llama_constraint_apply_impl(*cnstr, cur_p);
std::vector<float> probs(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
probs[i] = cur_p->data[i].p;
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
return cur_p->data[idx].id;
}
default:
GGML_ABORT("invalid sampler type");
}
}
@ -1224,40 +1252,3 @@ llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith)
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) {
return smpl.prev.size();
}
llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) {
if (probs) {
// if probs are needed, we apply softmax to get the probabilities
llama_constraint_softmax_impl(cur_p);
// the cur_p are sorted, so we can just return the first one
return cur_p->data[0].id;
}
// return the token with the highest logit
auto * max_iter = std::max_element(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit < b.logit;
});
llama_token result = max_iter->id;
return result;
}
llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng) {
llama_constraint_softmax_impl(cur_p);
std::vector<float> probs;
probs.reserve(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
probs.push_back(cur_p->data[i].p);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
llama_token result = cur_p->data[idx].id;
return result;
}

View file

@ -104,20 +104,18 @@ struct llama_sampler {
mutable int32_t n_sample;
};
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);
void llama_sampler_free_impl ( struct llama_sampler * smpl);
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
void llama_sampler_reset_impl( struct llama_sampler & smpl);
struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params);
void llama_sampler_free_impl ( struct llama_sampler * smpl);
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
void llama_sampler_reset_impl ( struct llama_sampler & smpl);
void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token);
void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr);
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl);
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith);
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type);
llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith);
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl);
llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs);
llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng);

View file

@ -17939,7 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() {
struct llama_sampler_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_prev =*/ 256,
/*.type =*/ LLAMA_SAMPLER_TYPE_GREEDY,
/*.type =*/ LLAMA_SAMPLER_TYPE_DIST,
};
return result;
@ -20713,6 +20713,20 @@ void llama_sampler_reset(struct llama_sampler * smpl) {
llama_sampler_reset_impl(*smpl);
}
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
llama_sampler_accept_impl(*smpl, token);
}
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
llama_sampler_apply_impl(*smpl, cur_p);
}
void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) {
const int n_vocab = smpl->vocab->n_vocab;
@ -20741,42 +20755,14 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample
return llama_sampler_constraint_get_impl(*smpl, i);
}
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
llama_sampler_accept_impl(*smpl, token);
}
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
llama_sampler_apply_impl(*smpl, cur_p);
}
llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
auto res = llama_sampler_sample_greedy_impl(cur_p, probs);
smpl->n_sample++;
return res;
}
llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
}
auto res = llama_sampler_sample_dist_impl(cur_p, smpl->rng);
auto res = llama_sampler_sample_impl(cur_p, smpl->rng, smpl->params.type);
smpl->n_sample++;