cont : use new API in examples

ggml-ci
Georgi Gerganov 2024-09-04 13:54:32 +03:00
parent 437376e708
commit a0b91214b4
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
21 changed files with 387 additions and 809 deletions


@ -2,6 +2,16 @@
#include "common.h"
struct gpt_sampler {
gpt_sampler_params params;
struct llama_constraint * bias;
struct llama_constraint * pnlt;
struct llama_constraint * grmr;
struct llama_sampler * smpl;
};
std::string gpt_sampler_params::print_all() const {
char result[1024];
@ -33,8 +43,6 @@ std::string gpt_sampler_params::print_constraints() const {
}
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
gpt_sampler * result = new gpt_sampler();
llama_sampler_params lparams = llama_sampler_default_params();
lparams.seed = params.seed;
@ -43,21 +51,23 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
lparams.mirostat_tau = params.mirostat_tau;
lparams.mirostat_eta = params.mirostat_eta;
result->smpl = llama_sampler_init(model, lparams);
llama_sampler_add_constraint(result->smpl, llama_constraint_init_logit_bias(
model,
params.logit_bias.size(),
params.logit_bias.data()));
llama_sampler_add_constraint(result->smpl, llama_constraint_init_penalties(
model,
params.penalty_last_n,
params.penalty_repeat,
params.penalty_freq,
params.penalty_present,
params.penalize_nl,
params.ignore_eos));
auto * result = new gpt_sampler {
.params = params,
.bias = llama_constraint_init_logit_bias(
model,
params.logit_bias.size(),
params.logit_bias.data()),
.pnlt = llama_constraint_init_penalties(
model,
params.penalty_last_n,
params.penalty_repeat,
params.penalty_freq,
params.penalty_present,
params.penalize_nl,
params.ignore_eos),
.grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"),
.smpl = llama_sampler_init(model, lparams)
};
for (const auto & cnstr : params.constraints) {
switch (cnstr) {
@ -84,14 +94,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
}
}
result->grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root");
return result;
}
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
if (gsmpl) {
llama_constraint_free(gsmpl->bias);
llama_constraint_free(gsmpl->pnlt);
llama_constraint_free(gsmpl->grmr);
llama_sampler_free(gsmpl->smpl);
delete gsmpl;
@ -121,18 +132,28 @@ void gpt_sampler_reset (struct gpt_sampler * gsmpl) {
llama_sampler_reset(gsmpl->smpl);
}
void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits) {
llama_sampler_set_logits(gsmpl->smpl, logits);
}
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
return llama_sampler_get_candidates(gsmpl->smpl);
}
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
return llama_sampler_last(gsmpl->smpl);
}
void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) {
llama_print_timings(ctx, gsmpl->smpl);
}
static llama_token gpt_sampler_sample(
struct llama_sampler * smpl,
struct llama_token_data_array * cur_p,
float temp,
int mirostat,
int n_probs) {
GGML_ASSERT(cur_p != nullptr && "candidates array must be provided");
llama_token res = 0;
if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) {
@ -142,6 +163,7 @@ static llama_token gpt_sampler_sample(
// greedy sampling, no probs
res = llama_sampler_sample_greedy(smpl, cur_p, false);
} else {
// apply all sampling constraints and then sample
llama_sampler_apply(smpl, cur_p);
if (mirostat != 0) {
@ -167,42 +189,62 @@ static llama_token gpt_sampler_sample(
return res;
}
llama_token gpt_sampler_sample(
struct gpt_sampler * gsmpl,
struct llama_context * ctx,
int idx) {
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
const auto & params = gsmpl->params;
auto & bias = gsmpl->bias;
auto & pnlt = gsmpl->pnlt;
auto & grmr = gsmpl->grmr;
auto & smpl = gsmpl->smpl;
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
auto * cur_p = llama_sampler_get_candidates(smpl);
// first, sample the token without any grammar constraints
const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
// create an array with a single token data element for the sampled id
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
llama_constraint_apply(grmr, &single_token_data_array);
// check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
// if the token is not valid, sample again, after applying the grammar constraints
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p);
// first, sample the token without any grammar constraints
const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);
// check if the sampled token fits the grammar
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
llama_constraint_apply(grmr, &single_token_data_array);
// check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}
// if the token is not valid, re-apply the constraints (bias, penalties, grammar) and sample again
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));
llama_constraint_apply(bias, cur_p);
llama_constraint_apply(pnlt, cur_p);
llama_constraint_apply(grmr, cur_p);
return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
}
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) {
GGML_ASSERT(candidates != nullptr);
llama_constraint_apply(gsmpl->grmr, candidates);
}
llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) {
return llama_sampler_sample_dist(gsmpl->smpl, candidates);
}
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs) {
return llama_sampler_sample_greedy(gsmpl->smpl, candidates, probs);
}
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
auto & smpl = gsmpl->smpl;


@ -60,15 +60,12 @@ struct gpt_sampler_params {
std::string print_constraints() const;
};
struct gpt_sampler {
gpt_sampler_params params;
struct llama_constraint * grmr = nullptr;
struct llama_sampler * smpl = nullptr;
};
// llama_sampler API overload
// gpt_sampler extends llama_sampler with additional functionality:
//
// - grammar support
// - custom sampler logic based on the parameters
//
struct gpt_sampler;
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
@ -79,8 +76,14 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl);
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits);
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);
// common sampling implementation:
//
// - set logits
@ -88,10 +91,12 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
llama_token gpt_sampler_sample(
struct gpt_sampler * gsmpl,
struct llama_context * ctx,
int idx);
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx);
void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates);
llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * candidates);
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs);
// helpers

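To make the migration concrete, here is a minimal, hypothetical generation loop assembled only from the gpt_sampler calls declared above (the include paths and the surrounding llama_decode() loop are assumptions, not code from this commit):

// sketch only: assumes `model`, `ctx` and a filled gpt_sampler_params come from the usual common helpers
#include "common.h"
#include "sampling.h"

static void generate_sketch(llama_model * model, llama_context * ctx, const gpt_sampler_params & sparams, int n_predict) {
    struct gpt_sampler * smpl = gpt_sampler_init(model, sparams);

    for (int i = 0; i < n_predict; ++i) {
        // fast path: sample at the last decoded position; the grammar re-check described
        // above happens inside gpt_sampler_sample()
        const llama_token id = gpt_sampler_sample(smpl, ctx, /* idx */ -1);

        // record the token in the sampler history and advance the grammar state
        gpt_sampler_accept(smpl, id, /* apply_grammar */ true);

        if (llama_token_is_eog(model, id)) {
            break;
        }

        // ... feed `id` back through llama_decode() before sampling the next token ...
    }

    gpt_print_timings(ctx, smpl);
    gpt_sampler_free(smpl);
}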

@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
auto sparams = llama_sampling_default_params();
auto sparams = llama_sampler_default_params();
sparams.seed = params.sparams.seed;
sparams.top_k = 40;
sparams.top_p = 0.9f;
sparams.temp = 0.4f;
llama_sampling * smpl = llama_sampling_init(model, sparams);
llama_sampler * smpl = llama_sampler_init(model, sparams);
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_p));
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@ -174,15 +175,11 @@ int main(int argc, char ** argv) {
const auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
llama_sampling_set_logits(smpl, logits);
llama_sampler_set_logits(smpl, logits);
llama_sampling_top_k(smpl, nullptr);
llama_sampling_top_p(smpl, nullptr);
llama_sampling_temp (smpl, nullptr);
const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);
const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr);
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
//const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@ -246,7 +243,7 @@ int main(int argc, char ** argv) {
llama_batch_free(batch);
llama_sampling_free(smpl);
llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);


@ -92,7 +92,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
return result;
}
static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) {
static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
std::string result;
const llama_model * model = llama_get_model(ctx);
@ -122,9 +122,9 @@ static std::string generate(llama_context * ctx, llama_sampling * smpl, const st
const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
llama_sampling_set_logits(smpl, logits);
llama_sampler_set_logits(smpl, logits);
llama_token token = llama_sampling_sample_greedy(smpl, nullptr);
llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false);
if (token == eos_token) {
break;
}
@ -171,7 +171,7 @@ int main(int argc, char * argv[]) {
// create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams);
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
@ -212,7 +212,7 @@ int main(int argc, char * argv[]) {
std::string response = generate(ctx, smpl, prompt, true);
}
llama_sampling_free(smpl);
llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();


@ -33,7 +33,7 @@
static llama_context ** g_ctx;
static llama_model ** g_model;
static llama_sampling ** g_smpl;
static gpt_sampler ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
} else {
console::cleanup();
printf("\n");
llama_print_timings(*g_ctx, *g_smpl);
gpt_print_timings(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
llama_sampling * smpl = nullptr;
gpt_sampler * smpl = nullptr;
g_model = &model;
g_ctx = &ctx;
@ -345,7 +345,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
smpl = llama_sampling_init(model, sparams);
smpl = gpt_sampler_init(model, sparams);
while (n_remain != 0 || params.interactive) {
// predict
@ -417,9 +417,9 @@ int main(int argc, char ** argv) {
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample(smpl, ctx, -1);
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
llama_sampling_accept(smpl, id, true);
gpt_sampler_accept(smpl, id, true);
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
@ -440,7 +440,7 @@ int main(int argc, char ** argv) {
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept(smpl, embd_inp[n_consumed], false);
gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
@ -472,7 +472,7 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
if (is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@ -538,7 +538,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of generation tokens in interactive mode
else if (llama_token_is_eog(model, llama_sampling_last(smpl))) {
else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
LOG("found EOS token\n");
if (params.interactive) {
@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
llama_sampling_reset(smpl);
gpt_sampler_reset(smpl);
}
is_interacting = false;
}
@ -634,13 +634,13 @@ int main(int argc, char ** argv) {
fflush(stdout);
}
llama_print_timings(ctx, smpl);
gpt_print_timings(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
llama_free(ctx);
llama_free_model(model);
llama_sampling_free(smpl);
gpt_sampler_free(smpl);
llama_backend_free();
#ifndef LOG_DISABLE_LOGS


@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
return true;
}
static const char * sample(struct llama_sampling * smpl,
static const char * sample(struct gpt_sampler * smpl,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
llama_sampling_accept(smpl, id, true);
const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
gpt_sampler_accept(smpl, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
LOG_TEE("\n");
struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
if (!smpl) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
fflush(stdout);
}
llama_sampling_free(smpl);
gpt_sampler_free(smpl);
printf("\n");
}


@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
}
static const char * sample(struct llama_sampling * smpl,
static const char * sample(struct gpt_sampler * smpl,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
llama_sampling_accept(smpl, id, true);
const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
gpt_sampler_accept(smpl, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
return ctx_llava;
}
static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
std::string user_prompt = prompt;
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
if (!is_first) {
@ -238,11 +238,11 @@ static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_
LOG_TEE("\n");
struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
return smpl;
}
static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){
static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
return tmp;
@ -296,7 +296,7 @@ int main(int argc, char ** argv) {
fflush(stdout);
}
llama_sampling_free(smpl);
gpt_sampler_free(smpl);
}else {
while (true) {
LOG_TEE("<user>");
@ -315,7 +315,7 @@ int main(int argc, char ** argv) {
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);
}
llama_sampling_free(smpl);
gpt_sampler_free(smpl);
}
}
printf("\n");


@ -117,7 +117,7 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
// target model sampling context
struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
// verification n-grams
std::vector<ngram_data> ngrams_cur(G);
@ -158,9 +158,9 @@ int main(int argc, char ** argv) {
// sample first token
{
id = llama_sampling_sample(smpl, ctx, 0);
id = gpt_sampler_sample(smpl, ctx, 0);
llama_sampling_accept(smpl, id, true);
gpt_sampler_accept(smpl, id, true);
{
const std::string token_str = llama_token_to_piece(ctx, id);
@ -283,9 +283,9 @@ int main(int argc, char ** argv) {
}
// sample the next token
id = llama_sampling_sample(smpl, ctx, i_batch);
id = gpt_sampler_sample(smpl, ctx, i_batch);
llama_sampling_accept(smpl, id, true);
gpt_sampler_accept(smpl, id, true);
// print
{
@ -360,7 +360,7 @@ int main(int argc, char ** argv) {
if (v == 0) {
// sample from the last level
for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
}
} else {
for (int i = 0; i < W; i++) {
@ -467,10 +467,11 @@ int main(int argc, char ** argv) {
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_accept = %d\n", n_accept);
llama_print_timings(ctx, smpl);
gpt_print_timings(ctx, smpl);
gpt_sampler_free(smpl);
llama_kv_cache_view_free(&kvc_view);
llama_sampling_free(smpl);
llama_batch_free(batch);


@ -104,7 +104,7 @@ int main(int argc, char ** argv){
bool has_eos = false;
struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
std::vector<llama_token> draft;
@ -128,9 +128,9 @@ int main(int argc, char ** argv){
int i_dft = 0;
while (true) {
// sample from the target model
llama_token id = llama_sampling_sample(smpl, ctx, i_dft);
llama_token id = gpt_sampler_sample(smpl, ctx, i_dft);
llama_sampling_accept(smpl, id, true);
gpt_sampler_accept(smpl, id, true);
const std::string token_str = llama_token_to_piece(ctx, id);
@ -239,9 +239,10 @@ int main(int argc, char ** argv){
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx, smpl);
gpt_print_timings(ctx, smpl);
gpt_sampler_free(smpl);
llama_sampling_free(smpl);
llama_batch_free(batch_tgt);
llama_free(ctx);


@ -106,7 +106,7 @@ static void sigint_handler(int signo) {
} else {
console::cleanup();
printf("\n");
llama_print_timings(*g_ctx, (*g_smpl)->smpl);
gpt_print_timings(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
@ -928,7 +928,7 @@ int main(int argc, char ** argv) {
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
llama_print_timings(ctx, smpl->smpl);
gpt_print_timings(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
gpt_sampler_free(smpl);


@ -51,7 +51,7 @@ static std::vector<std::string> k_prompts = {
struct client {
~client() {
if (smpl) {
llama_sampling_free(smpl);
gpt_sampler_free(smpl);
}
}
@ -72,7 +72,7 @@ struct client {
std::string prompt;
std::string response;
struct llama_sampling * smpl = nullptr;
struct gpt_sampler * smpl = nullptr;
};
static void print_date_time() {
@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.smpl = llama_sampling_init(model, params.sparams);
client.smpl = gpt_sampler_init(model, params.sparams);
}
std::vector<llama_token> tokens_system;
@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
client.prompt = client.input + "\nAssistant:";
client.response = "";
llama_sampling_reset(client.smpl);
gpt_sampler_reset(client.smpl);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i);
const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
llama_sampling_accept(client.smpl, id, true);
gpt_sampler_accept(client.smpl, id, true);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients


@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
// tokenize the prompt
std::vector<llama_token> tokens_list;
@ -218,10 +218,10 @@ int main(int argc, char ** argv) {
{
const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
llama_sampling_set_logits(smpl, logits);
llama_sampler_set_logits(smpl, logits);
// sample the most likely token
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
@ -262,9 +262,10 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
llama_sampler_free(smpl);
llama_batch_free(batch);
llama_sampling_free(smpl);
llama_free(ctx);
llama_free_model(model);


@ -38,10 +38,10 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampling_params sparams = llama_sampling_default_params();
llama_sampler_params sparams = llama_sampler_default_params();
sparams.seed = params.sparams.seed;
llama_sampling * smpl = llama_sampling_init(model, sparams);
llama_sampler * smpl = llama_sampler_init(model, sparams);
// tokenize prompt
auto tokens = llama_tokenize(ctx, params.prompt, true);
@ -71,9 +71,9 @@ int main(int argc, char ** argv) {
for (auto i = 0; i < params.n_predict; i++) {
const auto * logits = llama_get_logits(ctx);
llama_sampling_set_logits(smpl, logits);
llama_sampler_set_logits(smpl, logits);
auto next_token = llama_sampling_sample_dist(smpl, nullptr);
auto next_token = llama_sampler_sample_dist(smpl, nullptr);
auto next_token_str = llama_token_to_piece(ctx, next_token);
printf("%s", next_token_str.c_str());
@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
// make new context
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl2 = llama_sampling_init(model, sparams);
llama_sampler * smpl2 = llama_sampler_init(model, sparams);
printf("\nsecond run: %s", params.prompt.c_str());
@ -128,9 +128,9 @@ int main(int argc, char ** argv) {
for (auto i = 0; i < params.n_predict; i++) {
const auto * logits = llama_get_logits(ctx2);
llama_sampling_set_logits(smpl2, logits);
llama_sampler_set_logits(smpl2, logits);
auto next_token = llama_sampling_sample_dist(smpl2, nullptr);
auto next_token = llama_sampler_sample_dist(smpl2, nullptr);
auto next_token_str = llama_token_to_piece(ctx2, next_token);
printf("%s", next_token_str.c_str());
@ -157,7 +157,7 @@ int main(int argc, char ** argv) {
// make new context
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl3 = llama_sampling_init(model, sparams);
llama_sampler * smpl3 = llama_sampler_init(model, sparams);
printf("\nsingle seq run: %s", params.prompt.c_str());
@ -217,9 +217,9 @@ int main(int argc, char ** argv) {
for (auto i = 0; i < params.n_predict; i++) {
const auto * logits = llama_get_logits(ctx3);
llama_sampling_set_logits(smpl3, logits);
llama_sampler_set_logits(smpl3, logits);
auto next_token = llama_sampling_sample_dist(smpl3, nullptr);
auto next_token = llama_sampler_sample_dist(smpl3, nullptr);
auto next_token_str = llama_token_to_piece(ctx3, next_token);
printf("%s", next_token_str.c_str());
@ -236,9 +236,9 @@ int main(int argc, char ** argv) {
printf("\n");
llama_sampling_free(smpl);
llama_sampling_free(smpl2);
llama_sampling_free(smpl3);
llama_sampler_free(smpl);
llama_sampler_free(smpl2);
llama_sampler_free(smpl3);
llama_free(ctx3);
llama_free_model(model);


@ -170,10 +170,10 @@ struct server_slot {
// sampling
json json_schema;
struct gpt_sampling_params sparams;
struct gpt_sampler_params sparams;
struct gpt_sampler * smpl = nullptr;
llama_token sampled;
llama_sampling * smpl = nullptr;
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
@ -653,7 +653,7 @@ struct server_context {
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.smpl != nullptr) {
llama_sampling_free(slot.smpl);
gpt_sampler_free(slot.smpl);
}
}
@ -1027,26 +1027,26 @@ struct server_context {
}
{
const auto & samplers = data.find("samplers");
if (samplers != data.end() && samplers->is_array()) {
std::vector<std::string> sampler_names;
for (const auto & sampler_name : *samplers) {
if (sampler_name.is_string()) {
sampler_names.emplace_back(sampler_name);
const auto & constraints = data.find("samplers");
if (constraints != data.end() && constraints->is_array()) {
std::vector<std::string> constraint_names;
for (const auto & name : *constraints) {
if (name.is_string()) {
constraint_names.emplace_back(name);
}
}
slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false);
slot.sparams.constraints = gpt_constraint_types_from_names(constraint_names, false);
} else {
slot.sparams.samplers = default_sparams.samplers;
slot.sparams.constraints = default_sparams.constraints;
}
}
{
if (slot.smpl != nullptr) {
llama_sampling_free(slot.smpl);
gpt_sampler_free(slot.smpl);
}
slot.smpl = llama_sampling_init(model, slot.sparams);
slot.smpl = gpt_sampler_init(model, slot.sparams);
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@ -1253,10 +1253,10 @@ struct server_context {
}
json get_formated_generation(const server_slot & slot) const {
std::vector<std::string> samplers;
samplers.reserve(slot.sparams.samplers.size());
for (const auto & sampler : slot.sparams.samplers) {
samplers.emplace_back(llama_sampling_type_to_str(sampler));
std::vector<std::string> constraints;
constraints.reserve(slot.sparams.constraints.size());
for (const auto & constraint : slot.sparams.constraints) {
constraints.emplace_back(gpt_constraint_type_to_str(constraint));
}
return json {
@ -1290,7 +1290,7 @@ struct server_context {
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
{"samplers", samplers},
{"samplers", constraints},
};
}
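For context, a tiny hypothetical sketch of the two name helpers used above; the accepted constraint names ("top_k", "top_p", "temp") and the helpers' exact return types are assumptions, not shown in this diff:

// hypothetical input, e.g. parsed from the "samplers" field of a request
std::vector<std::string> names = { "top_k", "top_p", "temp" };

// map the names to constraint types (the second argument mirrors the server code above)
const auto constraints = gpt_constraint_types_from_names(names, false);

// and back to strings, as done when reporting the generation settings
for (const auto & cnstr : constraints) {
    const std::string name = gpt_constraint_type_to_str(cnstr);
    // ...
}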
@ -2084,7 +2084,7 @@ struct server_context {
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
llama_sampling_reset(slot.smpl);
gpt_sampler_reset(slot.smpl);
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
@ -2097,7 +2097,7 @@ struct server_context {
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false);
gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
}
}
}
@ -2150,7 +2150,7 @@ struct server_context {
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(slot.smpl);
gpt_sampler_reset(slot.smpl);
}
// remove the non-common part from the cache
@ -2332,9 +2332,9 @@ struct server_context {
}
completion_token_output result;
const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i);
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
llama_sampling_accept(slot.smpl, id, true);
gpt_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
@ -2345,7 +2345,7 @@ struct server_context {
result.tok = id;
const auto * cur_p = llama_sampling_get_candidates(slot.smpl);
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
// TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643
// fix if necessary


@ -55,7 +55,7 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
// tokenize the prompt
@ -114,10 +114,10 @@ int main(int argc, char ** argv) {
{
const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
llama_sampling_set_logits(smpl, logits);
llama_sampler_set_logits(smpl, logits);
// sample the most likely token
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@ -159,8 +159,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
llama_batch_free(batch);
llama_sampling_free(smpl);
llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);


@ -21,7 +21,7 @@ struct seq_draft {
std::vector<llama_token> tokens;
std::vector<std::vector<llama_token_data>> dists;
struct llama_sampling * smpl;
struct gpt_sampler * smpl = nullptr;
};
int main(int argc, char ** argv) {
@ -180,14 +180,14 @@ int main(int argc, char ** argv) {
bool has_eos = false;
// target model sampling context (reuse the llama_context's sampling instance)
struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams);
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
for (int s = 0; s < n_seq_dft; ++s) {
// allocate llama_sampling for each draft sequence
drafts[s].smpl = llama_sampling_init(model_dft, params.sparams);
// allocate gpt_sampler for each draft sequence
drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
}
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@ -229,13 +229,14 @@ int main(int argc, char ** argv) {
bool accept = false;
if (params.sparams.temp > 0) {
// stochastic verification
const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]));
gpt_sampler_set_logits(smpl, logits);
auto & dist_tgt = *llama_sampling_get_candidates(smpl);
auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
llama_sampling_grammar(smpl, &dist_tgt);
llama_sampling_softmax(smpl, &dist_tgt);
gpt_sampler_apply_grammar(smpl, &dist_tgt);
gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax
float p_tgt = 0.0f;
float p_dft = 0.0f;
@ -280,7 +281,7 @@ int main(int argc, char ** argv) {
accept = true;
token_id = drafts[s].tokens[i_dft];
token_str = llama_token_to_piece(ctx_tgt, token_id);
llama_sampling_accept(smpl, token_id, true);
gpt_sampler_accept(smpl, token_id, true);
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
break;
@ -334,8 +335,8 @@ int main(int argc, char ** argv) {
// all drafted tokens were rejected
// sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = llama_sampling_sample_dist(smpl, &dist_tgt);
llama_sampling_accept(smpl, token_id, true);
token_id = gpt_sampler_sample_dist(smpl, &dist_tgt);
gpt_sampler_accept(smpl, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
@ -344,9 +345,9 @@ int main(int argc, char ** argv) {
// sample from the target model
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(smpl, token_id, true);
gpt_sampler_accept(smpl, token_id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());
@ -436,7 +437,10 @@ int main(int argc, char ** argv) {
break;
}
llama_sampling_cp(smpl, drafts[0].smpl);
if (drafts[0].smpl) {
gpt_sampler_free(drafts[0].smpl);
}
drafts[0].smpl = gpt_sampler_cp(smpl);
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
@ -465,9 +469,9 @@ int main(int argc, char ** argv) {
continue;
}
llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl);
const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
@ -505,7 +509,11 @@ int main(int argc, char ** argv) {
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl);
if (drafts[n_seq_cur].smpl) {
gpt_sampler_free(drafts[n_seq_cur].smpl);
}
drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl);
sa.push_back(n_seq_cur);
@ -521,7 +529,7 @@ int main(int argc, char ** argv) {
const int s = sa[is];
llama_sampling_accept(drafts[s].smpl, id, true);
gpt_sampler_accept(drafts[s].smpl, id, true);
drafts[s].tokens.push_back(id);
// save cur_p.data into drafts[s].dists
@ -597,14 +605,14 @@ int main(int argc, char ** argv) {
LOG_TEE("\ndraft:\n");
// TODO: print sampling/grammar timings for all drafts
llama_print_timings(ctx_dft, nullptr);
gpt_print_timings(ctx_dft, nullptr);
LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt, smpl);
gpt_print_timings(ctx_tgt, smpl);
llama_sampling_free(smpl);
gpt_sampler_free(smpl);
for (int s = 0; s < n_seq_dft; ++s) {
llama_sampling_free(drafts[s].smpl);
gpt_sampler_free(drafts[s].smpl);
}
llama_batch_free(batch_dft);


@ -46,9 +46,6 @@
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
// TODO: remove before merge
#define LLAMA_MAX_SAMPLERS 16
#ifdef __cplusplus
extern "C" {
#endif
@ -1001,133 +998,6 @@ extern "C" {
//
// Sampling API
// TODO: remove before merge
//
// TODO: llama_model should become llama_vocab
//LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);
//LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
//// Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
//LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
//// - clear prev token
//// - reset grammar state
//LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
//// Sampling parameter mutation
//// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
//LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
//LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
//// Set the logits from which to sample.
//// This call initializes the internal token candidates array.
//// The internal candidates are implicitly used by the sampling API below when no candidates are provided.
//LLAMA_API void llama_sampling_set_logits(
// struct llama_sampling * smpl,
// const float * logits);
///// @details Returns the current candidate tokens.
//LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
// struct llama_sampling * smpl);
//// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
//// Each function can accept an array of token candidates. If the candidates are not provided, the internal
//// candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
///// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
//LLAMA_API void llama_sampling_softmax(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
//LLAMA_API void llama_sampling_top_k(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
//LLAMA_API void llama_sampling_top_p(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
//LLAMA_API void llama_sampling_min_p(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
//LLAMA_API void llama_sampling_tail_free(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
//LLAMA_API void llama_sampling_typical(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Apply temperature and entropy
//LLAMA_API void llama_sampling_temp(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Apply constraints from grammar
//LLAMA_API void llama_sampling_grammar(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
///// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
//LLAMA_API void llama_sampling_penalties(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
//LLAMA_API llama_token llama_sampling_sample_mirostat(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Selects the token with the highest probability.
///// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
//LLAMA_API llama_token llama_sampling_sample_greedy(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Randomly selects a token from the candidates based on their probability distribution.
//LLAMA_API llama_token llama_sampling_sample_dist(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
//LLAMA_API llama_token llama_sampling_sample(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates);
///// @details Accepts the sampled token into the sampling context.
///// - adds it to "prev" tokens
///// - updates the grammar state (if apply_grammar is true)
//LLAMA_API void llama_sampling_accept(
// struct llama_sampling * smpl,
// llama_token token,
// bool apply_grammar);
///// @details Get the number of accepted tokens so far (max of n_prev)
//LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
///// @details Get the ith accepted token
///// @param ith [0, n_prev), ith == 0 is the last accepted token.
///// returns LLAMA_TOKEN_NULL if ith is out of bounds
//LLAMA_API llama_token llama_sampling_prev(
// const struct llama_sampling * smpl,
// int32_t ith);
///// @details Get the last accepted token
///// Same as llama_sampling_prev(smpl, 0)
///// returns LLAMA_TOKEN_NULL if there are no accepted tokens
//LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
//
// Sampling v2 API
//
// - Constraints
// The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their
@ -1203,7 +1073,7 @@ extern "C" {
LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);
// do not call if used with llama_sampler_add_constraint
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint)
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
@ -1221,11 +1091,7 @@ extern "C" {
LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);
// TODO: should this take ownership so the user does not need to call llama_constraint_free
// or should just make a reference to the constraint so that it can be reused in multiple llama_sampler?
//
// seems better to take the ownership, otherwise the copying of the sampler will be more complicated
// important: takes ownership of the constraint object and will free it in llama_sampler_free
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
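As a quick reference, the batched and simple examples updated in this commit drive this constraint API roughly as follows (a sketch with illustrative parameter values; `model` and `ctx` are assumed to already exist):

llama_sampler_params sparams = llama_sampler_default_params();

llama_sampler * smpl = llama_sampler_init(model, sparams);

// the sampler takes ownership of each constraint and frees it in llama_sampler_free()
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));       // (top_k, min_keep)
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9f, 0.05f)); // (top_p, min_p)
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4f));

// set the logits of the last token; this initializes the internal candidates
llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, -1));

// nullptr -> sample from the internal candidates
const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);

llama_sampler_free(smpl);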


@ -421,113 +421,8 @@ void llama_constraint_penalties_impl(
candidates->sorted = false;
}
llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
llama_constraint_softmax_impl(candidates);
// Estimate s_hat using the most probable m tokens
float s_hat = 0.0;
float sum_ti_bi = 0.0;
float sum_ti_sq = 0.0;
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
float t_i = logf(float(i + 2) / float(i + 1));
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
sum_ti_bi += t_i * b_i;
sum_ti_sq += t_i * t_i;
}
s_hat = sum_ti_bi / sum_ti_sq;
// Compute k from the estimated s_hat and target surprise value
float epsilon_hat = s_hat - 1;
float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
// Sample the next word X using top-k sampling
llama_constraint_top_k_impl(candidates, int(k), 1);
llama_token X = llama_sampler_sample_dist_impl(candidates, rng);
// Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates->data[X_idx].p);
float e = observed_surprise - tau;
// Update mu using the learning rate and error
mu = mu - eta * e;
return X;
}
llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
llama_constraint_softmax_impl(candidates);
// Truncate the words with surprise values greater than mu
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return -log2f(candidate.p) > mu;
}));
if (candidates->size == 0) {
candidates->size = 1;
}
// Normalize the probabilities of the remaining words
llama_constraint_softmax_impl(candidates);
// Sample the next word X from the remaining words
llama_token X = llama_sampler_sample_dist_impl(candidates, rng);
// Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates->data[X_idx].p);
float e = observed_surprise - tau;
// Update mu using the learning rate and error
mu = mu - eta * e;
return X;
}
llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) {
if (probs) {
// if probs are needed, we apply softmax to get the probabilities
llama_constraint_softmax_impl(candidates);
// the candidates are sorted, so we can just return the first one
return candidates->data[0].id;
}
// return the token with the highest logit
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit < b.logit;
});
llama_token result = max_iter->id;
return result;
}
llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
llama_constraint_softmax_impl(candidates);
std::vector<float> probs;
probs.reserve(candidates->size);
for (size_t i = 0; i < candidates->size; ++i) {
probs.push_back(candidates->data[i].p);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
llama_token result = candidates->data[idx].id;
return result;
}
//
// sampling v2
// sampling
//
// constraints
@ -1172,3 +1067,107 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) {
return smpl.prev.size();
}
llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
llama_constraint_softmax_impl(candidates);
// Estimate s_hat using the most probable m tokens
float s_hat = 0.0;
float sum_ti_bi = 0.0;
float sum_ti_sq = 0.0;
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
float t_i = logf(float(i + 2) / float(i + 1));
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
sum_ti_bi += t_i * b_i;
sum_ti_sq += t_i * t_i;
}
s_hat = sum_ti_bi / sum_ti_sq;
// Compute k from the estimated s_hat and target surprise value
float epsilon_hat = s_hat - 1;
float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
// Sample the next word X using top-k sampling
llama_constraint_top_k_impl(candidates, int(k), 1);
llama_token X = llama_sampler_sample_dist_impl(candidates, rng);
// Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates->data[X_idx].p);
float e = observed_surprise - tau;
// Update mu using the learning rate and error
mu = mu - eta * e;
return X;
}
llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
llama_constraint_softmax_impl(candidates);
// Truncate the words with surprise values greater than mu
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return -log2f(candidate.p) > mu;
}));
if (candidates->size == 0) {
candidates->size = 1;
}
// Normalize the probabilities of the remaining words
llama_constraint_softmax_impl(candidates);
// Sample the next word X from the remaining words
llama_token X = llama_sampler_sample_dist_impl(candidates, rng);
// Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
return candidate.id == X;
}));
float observed_surprise = -log2f(candidates->data[X_idx].p);
float e = observed_surprise - tau;
// Update mu using the learning rate and error
mu = mu - eta * e;
return X;
}
llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) {
if (probs) {
// if probs are needed, we apply softmax to get the probabilities
llama_constraint_softmax_impl(candidates);
// the candidates are sorted, so we can just return the first one
return candidates->data[0].id;
}
// return the token with the highest logit
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit < b.logit;
});
llama_token result = max_iter->id;
return result;
}
llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
llama_constraint_softmax_impl(candidates);
std::vector<float> probs;
probs.reserve(candidates->size);
for (size_t i = 0; i < candidates->size; ++i) {
probs.push_back(candidates->data[i].p);
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
const int idx = dist(rng);
llama_token result = candidates->data[idx].id;
return result;
}


@ -10,6 +10,7 @@ struct llama_grammar;
using llama_token_cnt = std::unordered_map<llama_token, int>;
// TODO: tmp exposed, until tests start using llama_constraint
void llama_constraint_softmax_impl (struct llama_token_data_array * candidates);
void llama_constraint_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_constraint_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep);
@ -27,30 +28,6 @@ void llama_constraint_penalties_impl(
float penalty_freq,
float penalty_present);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs);
llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
//
// sampling v2
//
// constraints
struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep);
@ -128,3 +105,21 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d
llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith);
int llama_sampler_n_prev_impl(const struct llama_sampler & smpl);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs);
llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
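For orientation, a minimal sketch of driving the Mirostat v2 entry point declared above. It assumes the caller keeps `mu` alive across calls and initializes it to 2 * tau as the parameter notes describe; the helper name, seed, and the tau/eta values are illustrative only:

#include <random>

// Illustrative sketch, not part of the commit: sample one token with the
// Mirostat v2 implementation from an already-filled candidates array.
// `mu` must persist across calls; start it at 2.0f * tau.
static llama_token sample_mirostat_v2_sketch(llama_token_data_array * cur_p, float & mu) {
    static std::mt19937 rng(1234); // placeholder seed

    const float tau = 5.0f; // target surprise (cross-entropy)
    const float eta = 0.1f; // learning rate for the mu update

    return llama_sampler_sample_mirostat_v2_impl(cur_p, rng, tau, eta, mu);
}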

@ -20609,346 +20609,6 @@ int32_t llama_chat_apply_template(
// sampling
//
//struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
// return llama_sampling_init_impl(model->vocab, params);
//}
//void llama_sampling_free(struct llama_sampling * smpl) {
// if (smpl == nullptr) {
// return;
// }
// llama_sampling_free_impl(smpl);
//}
//struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
// return llama_sampling_cp_impl(*smpl);
//}
//void llama_sampling_reset(struct llama_sampling * smpl) {
// llama_sampling_reset_impl(*smpl);
//}
//void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
// llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
//}
//void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
// llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
//}
//void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) {
// const int n_vocab = smpl->vocab.n_vocab;
// smpl->cur.resize(n_vocab);
// for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
// smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
// }
// for (const auto & lb : smpl->logit_bias) {
// smpl->cur[lb.token].logit += lb.bias;
// }
// if (smpl->params.ignore_eos) {
// smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY;
// }
// smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false };
// // apply penalties
// {
// const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit;
// llama_sampling_penalties(smpl, &smpl->cur_p);
// if (!smpl->params.penalize_nl) {
// for (size_t idx = 0; idx < smpl->cur_p.size; idx++) {
// if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) {
// smpl->cur_p.data[idx].logit = nl_logit;
// break;
// }
// }
// }
// }
//}
//llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) {
// return &smpl->cur_p;
//}
//void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// llama_sampling_softmax_impl(candidates);
//}
//void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
//}
//void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
//}
//void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
//}
//void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
//}
//void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
//}
//void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// if (smpl->params.dynatemp_range > 0) {
// const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
// const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);
// llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent);
// } else {
// llama_sampling_temp_impl(candidates, smpl->params.temp);
// }
//}
//void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_grammar_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// if (smpl->grammar) {
// llama_sampling_grammar_impl(candidates, *smpl->grammar);
// smpl->n_grammar++;
// }
//}
//void llama_sampling_penalties(
// struct llama_sampling * smpl,
// llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());
// const float penalty_repeat = smpl->params.penalty_repeat;
// const float penalty_freq = smpl->params.penalty_freq;
// const float penalty_present = smpl->params.penalty_present;
// if ((penalty_last_n == 0) ||
// (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
// return;
// }
// // Create a frequency map to count occurrences of each token in last_tokens
// // TODO: move to sampling state and avoid reallocation
// llama_token_cnt token_count;
// for (size_t i = 0; i < penalty_last_n; ++i) {
// token_count[smpl->prev.rat(i)]++;
// }
// llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present);
//}
//llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// const auto type = smpl->params.mirostat;
// llama_token res;
// if (type == 1) {
// res = llama_sampling_sample_mirostat_impl(candidates,
// smpl->rng,
// smpl->params.mirostat_tau,
// smpl->params.mirostat_eta,
// 100,
// smpl->vocab.n_vocab,
// smpl->mirostat_mu);
// } else if (type == 2) {
// res = llama_sampling_sample_mirostat_v2_impl(candidates,
// smpl->rng,
// smpl->params.mirostat_tau,
// smpl->params.mirostat_eta,
// smpl->mirostat_mu);
// } else {
// GGML_ABORT("invalid mirostat type: %d", type);
// }
// smpl->n_sample++;
// return res;
//}
//llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// auto res = llama_sampling_sample_greedy_impl(candidates);
// smpl->n_sample++;
// return res;
//}
//llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
// smpl->n_sample++;
// return res;
//}
//llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
// time_meas tm(smpl->t_sample_us);
// if (candidates == nullptr) {
// candidates = &smpl->cur_p;
// }
// const auto & params = smpl->params;
// const float temp = params.temp;
// const int mirostat = params.mirostat;
// auto & cur_p = candidates;
// llama_token res = 0;
// if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
// // greedy sampling, with probs
// llama_sampling_softmax_impl(cur_p);
// res = cur_p->data[0].id;
// } else if (temp == 0.0f) {
// // greedy sampling, no probs
// res = llama_sampling_sample_greedy(smpl, cur_p);
// } else {
// if (mirostat != 0) {
// llama_sampling_temp(smpl, cur_p);
// res = llama_sampling_sample_mirostat(smpl, cur_p);
// } else {
// for (const auto & sampler : smpl->samplers) {
// switch (sampler) {
// case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
// case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
// case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
// case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
// case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
// case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break;
// default : break;
// }
// }
// res = llama_sampling_sample_dist(smpl, cur_p);
// //{
// // const int n_top = 10;
// // LOG("top %d candidates:\n", n_top);
// // for (int i = 0; i < n_top; i++) {
// // const llama_token id = cur_p.data[i].id;
// // (void)id; // To avoid a warning that id is unused when logging is disabled.
// // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
// // }
// //}
// //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
// }
// }
// smpl->n_sample++;
// return res;
//}
//void llama_sampling_accept(
// struct llama_sampling * smpl,
// llama_token token,
// bool apply_grammar) {
// time_meas tm(smpl->t_accept_us);
// llama_sampling_accept_impl(*smpl, token, apply_grammar);
// smpl->n_accept++;
//}
//int llama_sampling_n_prev(const struct llama_sampling * smpl) {
// return llama_sampling_n_prev_impl(*smpl);
//}
//llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
// return llama_sampling_prev_impl(*smpl, ith);
//}
//llama_token llama_sampling_last(const struct llama_sampling * smpl) {
// return llama_sampling_prev_impl(*smpl, 0);
//}
//
// sampling v2
//
struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) {
return llama_constraint_init_top_k_impl(k, min_keep);
}
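A hedged usage sketch for the wrapper above: build a top-k constraint and hand it to an existing sampler. llama_sampler_add_constraint is taken from the wider commit and `smpl` is assumed to be a sampler created with llama_sampler_init; treat the snippet as illustrative, not as the canonical call sequence:

// Illustrative only: attach a top-k constraint to a sampler.
struct llama_constraint * cnstr = llama_constraint_init_top_k(/*k =*/ 40, /*min_keep =*/ 1);

llama_sampler_add_constraint(smpl, cnstr); // smpl: struct llama_sampler *, assumed valid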
@ -21070,6 +20730,10 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) {
time_meas tm(smpl->t_sample_us);
if (candidates == nullptr) {
candidates = &smpl->cur_p;
}
llama_sampler_apply_impl(*smpl, candidates);
}
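The nullptr fallback above lets callers operate on the sampler's internal candidates instead of passing their own array. A minimal sketch of that path, where llama_sampler_set_logits and llama_sampler_get_candidates are assumed from the rest of this commit and `logits` points to the model output for the current position:

// Illustrative sketch (requires <random>): fill the internal candidates, apply
// the configured constraints, then sample from the resulting distribution.
llama_sampler_set_logits(smpl, logits);  // assumed helper: populates smpl->cur_p
llama_sampler_apply(smpl, nullptr);      // nullptr -> apply to the internal cur_p

llama_token_data_array * cur_p = llama_sampler_get_candidates(smpl); // assumed accessor

std::mt19937 rng(1234); // placeholder RNG; the library normally owns this state
const llama_token id = llama_sampler_sample_dist_impl(cur_p, rng);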

@ -30,9 +30,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sampling_softmax_impl(&candidates_p);
llama_constraint_softmax_impl(&candidates_p);
DUMP(&candidates_p);
llama_sampling_top_k_impl(&candidates_p, k, 1);
llama_constraint_top_k_impl(&candidates_p, k, 1);
DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
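For context, a self-contained sketch of what a renamed helper like the one above looks like end-to-end. Only the two llama_constraint_*_impl calls are taken from the diff; the candidate construction (log of the input probabilities) and the final assert are assumptions about the test's shape:

// Illustrative sketch of a top-k check using the constraint-level calls.
// Assumes llama.h, <vector>, <cmath> and <algorithm> are available as in the test file.
static void check_top_k_sketch(const std::vector<float> & probs, int32_t k) {
    std::vector<llama_token_data> candidates;
    candidates.reserve(probs.size());
    for (llama_token id = 0; id < (llama_token) probs.size(); ++id) {
        candidates.push_back({ id, logf(probs[id]), 0.0f }); // {id, logit, p}
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    llama_constraint_softmax_impl(&candidates_p);       // sort by probability, normalize
    llama_constraint_top_k_impl (&candidates_p, k, 1);  // keep at most k tokens (min_keep = 1)

    GGML_ASSERT(candidates_p.size <= (size_t) std::max(k, 1));
}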
@ -52,9 +52,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sampling_softmax_impl(&candidates_p);
llama_constraint_softmax_impl(&candidates_p);
DUMP(&candidates_p);
llama_sampling_top_p_impl(&candidates_p, p, 1);
llama_constraint_top_p_impl(&candidates_p, p, 1);
DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -75,7 +75,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p);
llama_sampling_tail_free_impl(&candidates_p, z, 1);
llama_constraint_tail_free_impl(&candidates_p, z, 1);
DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -96,9 +96,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p);
llama_sampling_min_p_impl(&candidates_p, p, 1);
llama_constraint_min_p_impl(&candidates_p, p, 1);
DUMP(&candidates_p);
llama_sampling_softmax_impl(&candidates_p);
llama_constraint_softmax_impl(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) {
@ -118,7 +118,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p);
llama_sampling_typical_impl(&candidates_p, p, 1);
llama_constraint_typical_impl(&candidates_p, p, 1);
DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -148,10 +148,10 @@ static void test_penalties(
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sampling_softmax_impl(&candidates_p);
llama_constraint_softmax_impl(&candidates_p);
DUMP(&candidates_p);
llama_sampling_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence);
llama_sampling_softmax_impl(&candidates_p);
llama_constraint_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence);
llama_constraint_softmax_impl(&candidates_p);
DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size());
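The penalties hunk above also needs a frequency map of recently generated tokens before the constraint call. A minimal sketch of that preparation, reusing the llama_token_cnt map type from the removed code earlier in this commit (an assumption that it is still available to the tests) and placeholder penalty values:

// Illustrative sketch: count recent tokens, rescale the candidate logits with
// the constraint-level penalties, then re-sort the distribution.
llama_token_cnt token_count; // frequency map keyed by token id (assumed type)
for (const llama_token tok : last_tokens) { // last_tokens: recent history, assumed available
    token_count[tok]++;
}

llama_constraint_penalties_impl(&candidates_p, token_count,
        /*repeat_penalty  =*/ 1.1f,
        /*alpha_frequency =*/ 0.0f,
        /*alpha_presence  =*/ 0.0f);
llama_constraint_softmax_impl(&candidates_p); // re-sort after the logits changed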
@ -176,16 +176,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
for (auto s : samplers_sequence) {
switch (s){
case 'k': llama_sampling_top_k_impl(&candidates_p, top_k, 1); break;
case 'k': llama_constraint_top_k_impl(&candidates_p, top_k, 1); break;
case 'f': GGML_ABORT("tail_free test not implemented");
case 'y': GGML_ABORT("typical test not implemented");
case 'p': llama_sampling_top_p_impl(&candidates_p, top_p, 1); break;
case 'm': llama_sampling_min_p_impl(&candidates_p, min_p, 1); break;
case 'p': llama_constraint_top_p_impl(&candidates_p, top_p, 1); break;
case 'm': llama_constraint_min_p_impl(&candidates_p, min_p, 1); break;
case 't': GGML_ABORT("temperature test not implemented");
default : GGML_ABORT("Unknown sampler");
}
llama_sampling_softmax_impl(&candidates_p); // make sure tokens are sorted for tests
llama_constraint_softmax_impl(&candidates_p); // make sure tokens are sorted for tests
const int size = candidates_p.size;