cont : use new API in examples
ggml-ci
parent 437376e708
commit a0b91214b4
21 changed files with 387 additions and 809 deletions
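For orientation, the sketch below condenses the per-example flow that the diffs in this commit migrate to. It is assembled only from calls shown in the hunks that follow; the model/context setup, the decode call and the common headers are assumed from the existing examples and are not part of this change.

    // minimal sketch, assuming model/ctx/params come from the usual example setup
    static void generate_n(llama_model * model, llama_context * ctx, const gpt_params & params, int n_predict) {
        struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

        for (int i = 0; i < n_predict; ++i) {
            // sample from the logits of the last decoded position
            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);

            // record the token and, if a grammar is configured, advance its state
            gpt_sampler_accept(smpl, id, true);

            if (llama_token_is_eog(model, id)) {
                break;
            }

            // ... feed `id` back through the decode step, as the examples already do ...
        }

        gpt_print_timings(ctx, smpl);
        gpt_sampler_free(smpl);
    }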
@@ -2,6 +2,16 @@
#include "common.h"

struct gpt_sampler {
    gpt_sampler_params params;

    struct llama_constraint * bias;
    struct llama_constraint * pnlt;
    struct llama_constraint * grmr;

    struct llama_sampler * smpl;
};

std::string gpt_sampler_params::print_all() const {
    char result[1024];

@@ -33,8 +43,6 @@ std::string gpt_sampler_params::print_constraints() const {
}

struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
    gpt_sampler * result = new gpt_sampler();

    llama_sampler_params lparams = llama_sampler_default_params();

    lparams.seed = params.seed;

@@ -43,21 +51,23 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
    lparams.mirostat_tau = params.mirostat_tau;
    lparams.mirostat_eta = params.mirostat_eta;

    result->smpl = llama_sampler_init(model, lparams);

    llama_sampler_add_constraint(result->smpl, llama_constraint_init_logit_bias(
            model,
            params.logit_bias.size(),
            params.logit_bias.data()));

    llama_sampler_add_constraint(result->smpl, llama_constraint_init_penalties(
            model,
            params.penalty_last_n,
            params.penalty_repeat,
            params.penalty_freq,
            params.penalty_present,
            params.penalize_nl,
            params.ignore_eos));
    auto * result = new gpt_sampler {
        .params = params,
        .bias = llama_constraint_init_logit_bias(
            model,
            params.logit_bias.size(),
            params.logit_bias.data()),
        .pnlt = llama_constraint_init_penalties(
            model,
            params.penalty_last_n,
            params.penalty_repeat,
            params.penalty_freq,
            params.penalty_present,
            params.penalize_nl,
            params.ignore_eos),
        .grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"),
        .smpl = llama_sampler_init(model, lparams)
    };

    for (const auto & cnstr : params.constraints) {
        switch (cnstr) {

@@ -84,14 +94,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
        }
    }

    result->grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root");

    return result;
}

void gpt_sampler_free(struct gpt_sampler * gsmpl) {
    if (gsmpl) {
        llama_constraint_free(gsmpl->bias);
        llama_constraint_free(gsmpl->pnlt);
        llama_constraint_free(gsmpl->grmr);

        llama_sampler_free(gsmpl->smpl);

        delete gsmpl;

@@ -121,18 +132,28 @@ void gpt_sampler_reset (struct gpt_sampler * gsmpl) {
    llama_sampler_reset(gsmpl->smpl);
}

void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits) {
    llama_sampler_set_logits(gsmpl->smpl, logits);
}

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
    return llama_sampler_get_candidates(gsmpl->smpl);
}

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
    return llama_sampler_last(gsmpl->smpl);
}

void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) {
    llama_print_timings(ctx, gsmpl->smpl);
}

static llama_token gpt_sampler_sample(
        struct llama_sampler * smpl,
        struct llama_token_data_array * cur_p,
        float temp,
        int mirostat,
        int n_probs) {
    GGML_ASSERT(cur_p != nullptr && "candidates array must be provided");

    llama_token res = 0;

    if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) {

@@ -142,6 +163,7 @@ static llama_token gpt_sampler_sample(
        // greedy sampling, no probs
        res = llama_sampler_sample_greedy(smpl, cur_p, false);
    } else {
        // apply all sampling constraints and then sample
        llama_sampler_apply(smpl, cur_p);

        if (mirostat != 0) {

@@ -167,42 +189,62 @@ static llama_token gpt_sampler_sample(
    return res;
}

llama_token gpt_sampler_sample(
        struct gpt_sampler * gsmpl,
        struct llama_context * ctx,
        int idx) {
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
    const auto & params = gsmpl->params;

    auto & bias = gsmpl->bias;
    auto & pnlt = gsmpl->pnlt;
    auto & grmr = gsmpl->grmr;
    auto & smpl = gsmpl->smpl;

    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));

    auto * cur_p = llama_sampler_get_candidates(smpl);

    // first, sample the token without any grammar constraints
    const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);

    // create an array with a single token data element for the sampled id
    llama_token_data single_token_data = { id, 1.0f, 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

    llama_constraint_apply(grmr, &single_token_data_array);

    // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
    if (is_valid) {
        return id;
    }

    // if the token is not valid, sample again, after applying the grammar constraints
    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));

    llama_constraint_apply(bias, cur_p);
    llama_constraint_apply(pnlt, cur_p);

    // first, sample the token without any grammar constraints
    const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs);

    // check if the sampled token fits the grammar
    {
        llama_token_data single_token_data = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

        llama_constraint_apply(grmr, &single_token_data_array);

        // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // if the token is not valid, sample again, first apply the grammar constraints and then sample
    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx));

    llama_constraint_apply(bias, cur_p);
    llama_constraint_apply(pnlt, cur_p);
    llama_constraint_apply(grmr, cur_p);

    return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs);
}

void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) {
    GGML_ASSERT(candidates != nullptr);

    llama_constraint_apply(gsmpl->grmr, candidates);
}

llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) {
    return llama_sampler_sample_dist(gsmpl->smpl, candidates);
}

llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs) {
    return llama_sampler_sample_greedy(gsmpl->smpl, candidates, probs);
}

std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
    auto & smpl = gsmpl->smpl;
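The key idea in gpt_sampler_sample() above is the grammar fast path: a token is drawn without grammar constraints first and only checked against the grammar afterwards, so the full candidate set is constrained only when that check fails. A condensed sketch of the check itself, using the same names as the code above:

    // wrap the already-sampled token in a one-element candidate array and let the
    // grammar constraint mask it; a logit of -INFINITY means the grammar rejects it
    llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

    llama_constraint_apply(grmr, &single_token_data_array);

    const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;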
@@ -60,15 +60,12 @@ struct gpt_sampler_params {
    std::string print_constraints() const;
};

struct gpt_sampler {
    gpt_sampler_params params;

    struct llama_constraint * grmr = nullptr;

    struct llama_sampler * smpl = nullptr;
};

// llama_sampler API overload
// gpt_sampler extends llama_sampler with additional functionality:
//
// - grammar support
// - custom sampler logic based on the parameters
//
struct gpt_sampler;

struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);

@@ -79,8 +76,14 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl);
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);

void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits);

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);

void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl);

// common sampling implementation:
//
// - set logits

@@ -88,10 +91,12 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
llama_token gpt_sampler_sample(
        struct gpt_sampler * gsmpl,
        struct llama_context * ctx,
        int idx);
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx);

void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates);

llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * candidates);
llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs);

// helpers
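Besides the all-in-one gpt_sampler_sample(), the header exposes lower-level pieces (logits, candidates, grammar application, greedy/dist sampling) that examples such as the speculative one drive directly. A small sketch of that usage, assuming gsmpl and ctx were created as above:

    static llama_token sample_with_grammar(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) {
        gpt_sampler_set_logits(gsmpl, llama_get_logits_ith(ctx, idx));

        llama_token_data_array * cur_p = gpt_sampler_get_candidates(gsmpl);

        gpt_sampler_apply_grammar(gsmpl, cur_p);                              // mask tokens the grammar forbids
        const llama_token id = gpt_sampler_sample_greedy(gsmpl, cur_p, true); // true -> compute probabilities first

        gpt_sampler_accept(gsmpl, id, true);

        return id;
    }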
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    auto sparams = llama_sampling_default_params();
    auto sparams = llama_sampler_default_params();

    sparams.seed = params.sparams.seed;
    sparams.top_k = 40;
    sparams.top_p = 0.9f;
    sparams.temp = 0.4f;

    llama_sampling * smpl = llama_sampling_init(model, sparams);
    llama_sampler * smpl = llama_sampler_init(model, sparams);

    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_p));
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));

    if (ctx == NULL) {
        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

@@ -174,15 +175,11 @@ int main(int argc, char ** argv) {

            const auto * logits = llama_get_logits_ith(ctx, i_batch[i]);

            llama_sampling_set_logits(smpl, logits);
            llama_sampler_set_logits(smpl, logits);

            llama_sampling_top_k(smpl, nullptr);
            llama_sampling_top_p(smpl, nullptr);
            llama_sampling_temp (smpl, nullptr);
            const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);

            const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr);

            //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
            //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr);

            // is it an end of generation? -> mark the stream as finished
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

@@ -246,7 +243,7 @@ int main(int argc, char ** argv) {

    llama_batch_free(batch);

    llama_sampling_free(smpl);
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
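The example above now builds its llama_sampler from individual constraints and then draws with llama_sampler_sample_dist. A self-contained sketch of that setup (the numeric values are illustrative; the example takes them from params.sparams and creates the sampler once, outside the decode loop):

    static llama_token sample_one(llama_model * model, llama_context * ctx, int i_batch) {
        llama_sampler_params sparams = llama_sampler_default_params();

        llama_sampler * smpl = llama_sampler_init(model, sparams);

        // the added constraints are applied by the sampler before a token is drawn
        llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));      // (top_k, min_keep)
        llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9f, 0.1f)); // (top_p, min_p)
        llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4f));

        llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, i_batch));

        const llama_token id = llama_sampler_sample_dist(smpl, nullptr); // nullptr -> use the internal candidates

        llama_sampler_free(smpl);

        return id;
    }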
@@ -92,7 +92,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
    return result;
}

static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) {
static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
    std::string result;

    const llama_model * model = llama_get_model(ctx);

@@ -122,9 +122,9 @@ static std::string generate(llama_context * ctx, llama_sampling * smpl, const st

        const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

        llama_sampling_set_logits(smpl, logits);
        llama_sampler_set_logits(smpl, logits);

        llama_token token = llama_sampling_sample_greedy(smpl, nullptr);
        llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false);
        if (token == eos_token) {
            break;
        }

@@ -171,7 +171,7 @@ int main(int argc, char * argv[]) {
    // create generation context
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
    llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    // ### Embedding/Representation ###
    // samples taken from: https://github.com/ContextualAI/gritlm#basic

@@ -212,7 +212,7 @@ int main(int argc, char * argv[]) {
        std::string response = generate(ctx, smpl, prompt, true);
    }

    llama_sampling_free(smpl);
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
@@ -33,7 +33,7 @@

static llama_context ** g_ctx;
static llama_model ** g_model;
static llama_sampling ** g_smpl;
static gpt_sampler ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;

@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
    } else {
        console::cleanup();
        printf("\n");
        llama_print_timings(*g_ctx, *g_smpl);
        gpt_print_timings(*g_ctx, *g_smpl);
        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
        _exit(130);
    }

@@ -167,7 +167,7 @@ int main(int argc, char ** argv) {

    llama_model * model = nullptr;
    llama_context * ctx = nullptr;
    llama_sampling * smpl = nullptr;
    gpt_sampler * smpl = nullptr;

    g_model = &model;
    g_ctx = &ctx;

@@ -345,7 +345,7 @@ int main(int argc, char ** argv) {

    std::vector<llama_token> embd;

    smpl = llama_sampling_init(model, sparams);
    smpl = gpt_sampler_init(model, sparams);

    while (n_remain != 0 || params.interactive) {
        // predict

@@ -417,9 +417,9 @@ int main(int argc, char ** argv) {
        embd.clear();

        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
            const llama_token id = llama_sampling_sample(smpl, ctx, -1);
            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);

            llama_sampling_accept(smpl, id, true);
            gpt_sampler_accept(smpl, id, true);

            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());

@@ -440,7 +440,7 @@ int main(int argc, char ** argv) {

                // push the prompt in the sampling context in order to apply repetition penalties later
                // for the prompt, we don't apply grammar rules
                llama_sampling_accept(smpl, embd_inp[n_consumed], false);
                gpt_sampler_accept(smpl, embd_inp[n_consumed], false);

                ++n_consumed;
                if ((int) embd.size() >= params.n_batch) {

@@ -472,7 +472,7 @@ int main(int argc, char ** argv) {
        // if not currently processing queued inputs;
        if ((int) embd_inp.size() <= n_consumed) {
            // deal with eot token in infill mode
            if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
            if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
                if (is_interacting && !params.interactive_first) {
                    // print an eot token
                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());

@@ -538,7 +538,7 @@ int main(int argc, char ** argv) {
                is_interacting = false;
            }
            // deal with end of generation tokens in interactive mode
            else if (llama_token_is_eog(model, llama_sampling_last(smpl))) {
            else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
                LOG("found EOS token\n");

                if (params.interactive) {

@@ -611,7 +611,7 @@ int main(int argc, char ** argv) {

            if (n_past > 0) {
                if (is_interacting) {
                    llama_sampling_reset(smpl);
                    gpt_sampler_reset(smpl);
                }
                is_interacting = false;
            }

@@ -634,13 +634,13 @@ int main(int argc, char ** argv) {
        fflush(stdout);
    }

    llama_print_timings(ctx, smpl);
    gpt_print_timings(ctx, smpl);
    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

    llama_free(ctx);
    llama_free_model(model);

    llama_sampling_free(smpl);
    gpt_sampler_free(smpl);
    llama_backend_free();

#ifndef LOG_DISABLE_LOGS
@@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
    return true;
}

static const char * sample(struct llama_sampling * smpl,
static const char * sample(struct gpt_sampler * smpl,
                           struct llama_context * ctx_llama,
                           int * n_past) {
    const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
    llama_sampling_accept(smpl, id, true);
    const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
    gpt_sampler_accept(smpl, id, true);
    static std::string ret;
    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
        ret = "</s>";

@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

    LOG_TEE("\n");

    struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
    if (!smpl) {
        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
        exit(1);

@@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
        fflush(stdout);
    }

    llama_sampling_free(smpl);
    gpt_sampler_free(smpl);
    printf("\n");
}
@@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
    LOG_TEE("%s: image token past: %d\n", __func__, n_past);
}

static const char * sample(struct llama_sampling * smpl,
static const char * sample(struct gpt_sampler * smpl,
                           struct llama_context * ctx_llama,
                           int * n_past) {
    const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
    llama_sampling_accept(smpl, id, true);
    const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
    gpt_sampler_accept(smpl, id, true);
    static std::string ret;
    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
        ret = "</s>";

@@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
    return ctx_llava;
}

static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
    std::string user_prompt = prompt;
    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
    if (!is_first) {

@@ -238,11 +238,11 @@ static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_

    LOG_TEE("\n");

    struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
    return smpl;
}

static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){
static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){

    const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
    return tmp;

@@ -296,7 +296,7 @@ int main(int argc, char ** argv) {

            fflush(stdout);
        }
        llama_sampling_free(smpl);
        gpt_sampler_free(smpl);
    }else {
        while (true) {
            LOG_TEE("<user>");

@@ -315,7 +315,7 @@ int main(int argc, char ** argv) {
            if (strstr(response.c_str(), "<user>")) break; // minicpm-v
            fflush(stdout);
        }
        llama_sampling_free(smpl);
        gpt_sampler_free(smpl);
    }
}
printf("\n");
@@ -117,7 +117,7 @@ int main(int argc, char ** argv) {
    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

    // target model sampling context
    struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

    // verification n-grams
    std::vector<ngram_data> ngrams_cur(G);

@@ -158,9 +158,9 @@ int main(int argc, char ** argv) {

    // sample first token
    {
        id = llama_sampling_sample(smpl, ctx, 0);
        id = gpt_sampler_sample(smpl, ctx, 0);

        llama_sampling_accept(smpl, id, true);
        gpt_sampler_accept(smpl, id, true);

        {
            const std::string token_str = llama_token_to_piece(ctx, id);

@@ -283,9 +283,9 @@ int main(int argc, char ** argv) {
            }

            // sample the next token
            id = llama_sampling_sample(smpl, ctx, i_batch);
            id = gpt_sampler_sample(smpl, ctx, i_batch);

            llama_sampling_accept(smpl, id, true);
            gpt_sampler_accept(smpl, id, true);

            // print
            {

@@ -360,7 +360,7 @@ int main(int argc, char ** argv) {
                if (v == 0) {
                    // sample from the last level
                    for (int i = 0; i < W; i++) {
                        tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                        tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                    }
                } else {
                    for (int i = 0; i < W; i++) {

@@ -467,10 +467,11 @@ int main(int argc, char ** argv) {
    LOG_TEE("n_predict = %d\n", n_predict);
    LOG_TEE("n_accept = %d\n", n_accept);

    llama_print_timings(ctx, smpl);
    gpt_print_timings(ctx, smpl);

    gpt_sampler_free(smpl);

    llama_kv_cache_view_free(&kvc_view);
    llama_sampling_free(smpl);

    llama_batch_free(batch);
@@ -104,7 +104,7 @@ int main(int argc, char ** argv){

    bool has_eos = false;

    struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

    std::vector<llama_token> draft;

@@ -128,9 +128,9 @@ int main(int argc, char ** argv){
        int i_dft = 0;
        while (true) {
            // sample from the target model
            llama_token id = llama_sampling_sample(smpl, ctx, i_dft);
            llama_token id = gpt_sampler_sample(smpl, ctx, i_dft);

            llama_sampling_accept(smpl, id, true);
            gpt_sampler_accept(smpl, id, true);

            const std::string token_str = llama_token_to_piece(ctx, id);

@@ -239,9 +239,10 @@ int main(int argc, char ** argv){
    LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

    LOG_TEE("\ntarget:\n");
    llama_print_timings(ctx, smpl);
    gpt_print_timings(ctx, smpl);

    gpt_sampler_free(smpl);

    llama_sampling_free(smpl);
    llama_batch_free(batch_tgt);

    llama_free(ctx);
@@ -106,7 +106,7 @@ static void sigint_handler(int signo) {
    } else {
        console::cleanup();
        printf("\n");
        llama_print_timings(*g_ctx, (*g_smpl)->smpl);
        gpt_print_timings(*g_ctx, *g_smpl);
        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
        _exit(130);
    }

@@ -928,7 +928,7 @@ int main(int argc, char ** argv) {
        llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
    }

    llama_print_timings(ctx, smpl->smpl);
    gpt_print_timings(ctx, smpl);
    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

    gpt_sampler_free(smpl);
@@ -51,7 +51,7 @@ static std::vector<std::string> k_prompts = {
struct client {
    ~client() {
        if (smpl) {
            llama_sampling_free(smpl);
            gpt_sampler_free(smpl);
        }
    }

@@ -72,7 +72,7 @@ struct client {
    std::string prompt;
    std::string response;

    struct llama_sampling * smpl = nullptr;
    struct gpt_sampler * smpl = nullptr;
};

static void print_date_time() {

@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
    for (size_t i = 0; i < clients.size(); ++i) {
        auto & client = clients[i];
        client.id = i;
        client.smpl = llama_sampling_init(model, params.sparams);
        client.smpl = gpt_sampler_init(model, params.sparams);
    }

    std::vector<llama_token> tokens_system;

@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
                client.prompt = client.input + "\nAssistant:";
                client.response = "";

                llama_sampling_reset(client.smpl);
                gpt_sampler_reset(client.smpl);

                // do not prepend BOS because we have a system prompt!
                std::vector<llama_token> tokens_prompt;

@@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
            //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
            // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);

            const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i);
            const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);

            llama_sampling_accept(client.smpl, id, true);
            gpt_sampler_accept(client.smpl, id, true);

            if (client.n_decoded == 1) {
                // start measuring generation time after the first token to make sure all concurrent clients
@@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
        return 1;
    }

    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
    llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    // tokenize the prompt
    std::vector<llama_token> tokens_list;

@@ -218,10 +218,10 @@ int main(int argc, char ** argv) {
        {
            const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

            llama_sampling_set_logits(smpl, logits);
            llama_sampler_set_logits(smpl, logits);

            // sample the most likely token
            const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
            const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

@@ -262,9 +262,10 @@ int main(int argc, char ** argv) {

    fprintf(stderr, "\n");

    llama_sampler_free(smpl);

    llama_batch_free(batch);

    llama_sampling_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
@@ -38,10 +38,10 @@ int main(int argc, char ** argv) {
        return 1;
    }

    llama_sampling_params sparams = llama_sampling_default_params();
    llama_sampler_params sparams = llama_sampler_default_params();
    sparams.seed = params.sparams.seed;

    llama_sampling * smpl = llama_sampling_init(model, sparams);
    llama_sampler * smpl = llama_sampler_init(model, sparams);

    // tokenize prompt
    auto tokens = llama_tokenize(ctx, params.prompt, true);

@@ -71,9 +71,9 @@ int main(int argc, char ** argv) {
    for (auto i = 0; i < params.n_predict; i++) {
        const auto * logits = llama_get_logits(ctx);

        llama_sampling_set_logits(smpl, logits);
        llama_sampler_set_logits(smpl, logits);

        auto next_token = llama_sampling_sample_dist(smpl, nullptr);
        auto next_token = llama_sampler_sample_dist(smpl, nullptr);
        auto next_token_str = llama_token_to_piece(ctx, next_token);

        printf("%s", next_token_str.c_str());

@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
    // make new context
    auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

    llama_sampling * smpl2 = llama_sampling_init(model, sparams);
    llama_sampler * smpl2 = llama_sampler_init(model, sparams);

    printf("\nsecond run: %s", params.prompt.c_str());

@@ -128,9 +128,9 @@ int main(int argc, char ** argv) {
    for (auto i = 0; i < params.n_predict; i++) {
        const auto * logits = llama_get_logits(ctx2);

        llama_sampling_set_logits(smpl2, logits);
        llama_sampler_set_logits(smpl2, logits);

        auto next_token = llama_sampling_sample_dist(smpl2, nullptr);
        auto next_token = llama_sampler_sample_dist(smpl2, nullptr);
        auto next_token_str = llama_token_to_piece(ctx2, next_token);

        printf("%s", next_token_str.c_str());

@@ -157,7 +157,7 @@ int main(int argc, char ** argv) {
    // make new context
    auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

    llama_sampling * smpl3 = llama_sampling_init(model, sparams);
    llama_sampler * smpl3 = llama_sampler_init(model, sparams);

    printf("\nsingle seq run: %s", params.prompt.c_str());

@@ -217,9 +217,9 @@ int main(int argc, char ** argv) {
    for (auto i = 0; i < params.n_predict; i++) {
        const auto * logits = llama_get_logits(ctx3);

        llama_sampling_set_logits(smpl3, logits);
        llama_sampler_set_logits(smpl3, logits);

        auto next_token = llama_sampling_sample_dist(smpl3, nullptr);
        auto next_token = llama_sampler_sample_dist(smpl3, nullptr);
        auto next_token_str = llama_token_to_piece(ctx3, next_token);

        printf("%s", next_token_str.c_str());

@@ -236,9 +236,9 @@ int main(int argc, char ** argv) {

    printf("\n");

    llama_sampling_free(smpl);
    llama_sampling_free(smpl2);
    llama_sampling_free(smpl3);
    llama_sampler_free(smpl);
    llama_sampler_free(smpl2);
    llama_sampler_free(smpl3);

    llama_free(ctx3);
    llama_free_model(model);
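This example relies on the sampler being deterministic for a fixed seed: two llama_sampler instances created with the same parameters draw the same token from the same logits, which is what makes the saved and restored runs comparable. A sketch of that property (the exact type of the seed field is assumed here):

    static bool same_draw(llama_model * model, const float * logits, uint32_t seed) {
        llama_sampler_params sp = llama_sampler_default_params();
        sp.seed = seed;

        llama_sampler * a = llama_sampler_init(model, sp);
        llama_sampler * b = llama_sampler_init(model, sp);

        llama_sampler_set_logits(a, logits);
        llama_sampler_set_logits(b, logits);

        const bool same = llama_sampler_sample_dist(a, nullptr) == llama_sampler_sample_dist(b, nullptr);

        llama_sampler_free(a);
        llama_sampler_free(b);

        return same;
    }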
@@ -170,10 +170,10 @@ struct server_slot {
    // sampling
    json json_schema;

    struct gpt_sampling_params sparams;
    struct gpt_sampler_params sparams;
    struct gpt_sampler * smpl = nullptr;

    llama_token sampled;
    llama_sampling * smpl = nullptr;

    int32_t ga_i = 0; // group-attention state
    int32_t ga_n = 1; // group-attention factor

@@ -653,7 +653,7 @@ struct server_context {
        // Clear any sampling context
        for (server_slot & slot : slots) {
            if (slot.smpl != nullptr) {
                llama_sampling_free(slot.smpl);
                gpt_sampler_free(slot.smpl);
            }
        }

@@ -1027,26 +1027,26 @@ struct server_context {
        }

        {
            const auto & samplers = data.find("samplers");
            if (samplers != data.end() && samplers->is_array()) {
                std::vector<std::string> sampler_names;
                for (const auto & sampler_name : *samplers) {
                    if (sampler_name.is_string()) {
                        sampler_names.emplace_back(sampler_name);
            const auto & constraints = data.find("samplers");
            if (constraints != data.end() && constraints->is_array()) {
                std::vector<std::string> constraint_names;
                for (const auto & name : *constraints) {
                    if (name.is_string()) {
                        constraint_names.emplace_back(name);
                    }
                }
                slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false);
                slot.sparams.constraints = gpt_constraint_types_from_names(constraint_names, false);
            } else {
                slot.sparams.samplers = default_sparams.samplers;
                slot.sparams.constraints = default_sparams.constraints;
            }
        }

        {
            if (slot.smpl != nullptr) {
                llama_sampling_free(slot.smpl);
                gpt_sampler_free(slot.smpl);
            }

            slot.smpl = llama_sampling_init(model, slot.sparams);
            slot.smpl = gpt_sampler_init(model, slot.sparams);
            if (slot.smpl == nullptr) {
                // for now, the only error that may happen here is invalid grammar
                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);

@@ -1253,10 +1253,10 @@ struct server_context {
    }

    json get_formated_generation(const server_slot & slot) const {
        std::vector<std::string> samplers;
        samplers.reserve(slot.sparams.samplers.size());
        for (const auto & sampler : slot.sparams.samplers) {
            samplers.emplace_back(llama_sampling_type_to_str(sampler));
        std::vector<std::string> constraints;
        constraints.reserve(slot.sparams.constraints.size());
        for (const auto & constraint : slot.sparams.constraints) {
            constraints.emplace_back(gpt_constraint_type_to_str(constraint));
        }

        return json {

@@ -1290,7 +1290,7 @@ struct server_context {
            {"n_probs", slot.sparams.n_probs},
            {"min_keep", slot.sparams.min_keep},
            {"grammar", slot.sparams.grammar},
            {"samplers", samplers},
            {"samplers", constraints},
        };
    }

@@ -2084,7 +2084,7 @@ struct server_context {
                    GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                }

                llama_sampling_reset(slot.smpl);
                gpt_sampler_reset(slot.smpl);

                if (!slot.params.cache_prompt) {
                    slot.n_past_se = 0;

@@ -2097,7 +2097,7 @@ struct server_context {

                    // push the prompt into the sampling context (do not apply grammar)
                    for (int i = 0; i < slot.n_past; ++i) {
                        llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false);
                        gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                    }
                }
            }

@@ -2150,7 +2150,7 @@ struct server_context {
                        slot.n_past_se = 0;
                        slot.ga_i = 0;
                        // TODO: is the system prompt ever in the sampling context?
                        llama_sampling_reset(slot.smpl);
                        gpt_sampler_reset(slot.smpl);
                    }

                    // remove the non-common part from the cache

@@ -2332,9 +2332,9 @@ struct server_context {
            }

            completion_token_output result;
            const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i);
            const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);

            llama_sampling_accept(slot.smpl, id, true);
            gpt_sampler_accept(slot.smpl, id, true);

            slot.n_decoded += 1;
            if (slot.n_decoded == 1) {

@@ -2345,7 +2345,7 @@ struct server_context {

            result.tok = id;

            const auto * cur_p = llama_sampling_get_candidates(slot.smpl);
            const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);

            // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643
            // fix if necessary
@@ -55,7 +55,7 @@ int main(int argc, char ** argv) {
        return 1;
    }

    llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
    llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    // tokenize the prompt

@@ -114,10 +114,10 @@ int main(int argc, char ** argv) {
        {
            const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

            llama_sampling_set_logits(smpl, logits);
            llama_sampler_set_logits(smpl, logits);

            // sample the most likely token
            const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
            const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

@@ -159,8 +159,7 @@ int main(int argc, char ** argv) {
    fprintf(stderr, "\n");

    llama_batch_free(batch);

    llama_sampling_free(smpl);
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
@@ -21,7 +21,7 @@ struct seq_draft {
    std::vector<llama_token> tokens;
    std::vector<std::vector<llama_token_data>> dists;

    struct llama_sampling * smpl;
    struct gpt_sampler * smpl = nullptr;
};

int main(int argc, char ** argv) {

@@ -180,14 +180,14 @@ int main(int argc, char ** argv) {
    bool has_eos = false;

    // target model sampling context (reuse the llama_context's sampling instance)
    struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);

    // draft sequence data
    std::vector<seq_draft> drafts(n_seq_dft);

    for (int s = 0; s < n_seq_dft; ++s) {
        // allocate llama_sampling for each draft sequence
        drafts[s].smpl = llama_sampling_init(model_dft, params.sparams);
        // allocate gpt_sampler for each draft sequence
        drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
    }

    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);

@@ -229,13 +229,14 @@ int main(int argc, char ** argv) {
            bool accept = false;
            if (params.sparams.temp > 0) {
                // stochastic verification
                const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

                llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]));
                gpt_sampler_set_logits(smpl, logits);

                auto & dist_tgt = *llama_sampling_get_candidates(smpl);
                auto & dist_tgt = *gpt_sampler_get_candidates(smpl);

                llama_sampling_grammar(smpl, &dist_tgt);
                llama_sampling_softmax(smpl, &dist_tgt);
                gpt_sampler_apply_grammar(smpl, &dist_tgt);
                gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax

                float p_tgt = 0.0f;
                float p_dft = 0.0f;

@@ -280,7 +281,7 @@ int main(int argc, char ** argv) {
                        accept = true;
                        token_id = drafts[s].tokens[i_dft];
                        token_str = llama_token_to_piece(ctx_tgt, token_id);
                        llama_sampling_accept(smpl, token_id, true);
                        gpt_sampler_accept(smpl, token_id, true);

                        LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
                        break;

@@ -334,8 +335,8 @@ int main(int argc, char ** argv) {
                    // all drafted tokens were rejected
                    // sample from the target model
                    LOG("all drafted tokens were rejected, sampling from residual distribution\n");
                    token_id = llama_sampling_sample_dist(smpl, &dist_tgt);
                    llama_sampling_accept(smpl, token_id, true);
                    token_id = gpt_sampler_sample_dist(smpl, &dist_tgt);
                    gpt_sampler_accept(smpl, token_id, true);
                    token_str = llama_token_to_piece(ctx_tgt, token_id);
                }

@@ -344,9 +345,9 @@ int main(int argc, char ** argv) {

                // sample from the target model
                LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
                token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
                token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

                llama_sampling_accept(smpl, token_id, true);
                gpt_sampler_accept(smpl, token_id, true);

                //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());

@@ -436,7 +437,10 @@ int main(int argc, char ** argv) {
            break;
        }

        llama_sampling_cp(smpl, drafts[0].smpl);
        if (drafts[0].smpl) {
            gpt_sampler_free(drafts[0].smpl);
        }
        drafts[0].smpl = gpt_sampler_cp(smpl);

        int n_seq_cur = 1;
        int n_past_cur = n_past_dft;

@@ -465,9 +469,9 @@ int main(int argc, char ** argv) {
                continue;
            }

            llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
            gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);

            const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl);
            const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);

            for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",

@@ -505,7 +509,11 @@ int main(int argc, char ** argv) {
                    drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                    drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

                    llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl);
                    if (drafts[n_seq_cur].smpl) {
                        gpt_sampler_free(drafts[n_seq_cur].smpl);
                    }
                    drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl);

                    sa.push_back(n_seq_cur);

@@ -521,7 +529,7 @@ int main(int argc, char ** argv) {

                const int s = sa[is];

                llama_sampling_accept(drafts[s].smpl, id, true);
                gpt_sampler_accept(drafts[s].smpl, id, true);

                drafts[s].tokens.push_back(id);
                // save cur_p.data into drafts[s].dists

@@ -597,14 +605,14 @@ int main(int argc, char ** argv) {

    LOG_TEE("\ndraft:\n");
    // TODO: print sampling/grammar timings for all drafts
    llama_print_timings(ctx_dft, nullptr);
    gpt_print_timings(ctx_dft, nullptr);

    LOG_TEE("\ntarget:\n");
    llama_print_timings(ctx_tgt, smpl);
    gpt_print_timings(ctx_tgt, smpl);

    llama_sampling_free(smpl);
    gpt_sampler_free(smpl);
    for (int s = 0; s < n_seq_dft; ++s) {
        llama_sampling_free(drafts[s].smpl);
        gpt_sampler_free(drafts[s].smpl);
    }

    llama_batch_free(batch_dft);
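Note the new copy pattern above: gpt_sampler_cp(smpl) returns a freshly allocated copy instead of copying into an existing destination the way the old llama_sampling_cp(src, dst) did, so a draft's previous sampler has to be freed before it is replaced. Condensed, with the names from the example:

    // replace a draft sequence's sampler with a copy of another sampler's state
    if (drafts[s].smpl) {
        gpt_sampler_free(drafts[s].smpl);
    }
    drafts[s].smpl = gpt_sampler_cp(smpl);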
138 include/llama.h
@@ -46,9 +46,6 @@
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2

// TODO: remove before merge
#define LLAMA_MAX_SAMPLERS 16

#ifdef __cplusplus
extern "C" {
#endif

@@ -1001,133 +998,6 @@ extern "C" {

    //
    // Sampling API
    // TODO: remove before merge
    //

    // TODO: llama_model should become llama_vocab
    //LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params);

    //LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);

    //// Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
    //LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);

    //// - clear prev token
    //// - reset grammar state
    //LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);

    //// Sampling parameter mutation
    //// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
    //LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
    //LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);

    //// Set the logits from which to sample.
    //// This call initializes the internal token candidates array.
    //// The internal candidates are implicitly used by the sampling API below when no candidates are provided.
    //LLAMA_API void llama_sampling_set_logits(
    //        struct llama_sampling * smpl,
    //        const float * logits);

    ///// @details Returns the current candidate tokens.
    //LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
    //        struct llama_sampling * smpl);

    //// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
    //// Each function can accept an array of token candidates. If the candidates are not provided, the internal
    //// candidates are used. The internal candidates are initialized by llama_sampling_set_logits().

    ///// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
    //LLAMA_API void llama_sampling_softmax(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
    //LLAMA_API void llama_sampling_top_k(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
    //LLAMA_API void llama_sampling_top_p(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
    //LLAMA_API void llama_sampling_min_p(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
    //LLAMA_API void llama_sampling_tail_free(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
    //LLAMA_API void llama_sampling_typical(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Apply temperature and entropy
    //LLAMA_API void llama_sampling_temp(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Apply constraints from grammar
    //LLAMA_API void llama_sampling_grammar(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
    ///// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
    //LLAMA_API void llama_sampling_penalties(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
    //LLAMA_API llama_token llama_sampling_sample_mirostat(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Selects the token with the highest probability.
    ///// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
    //LLAMA_API llama_token llama_sampling_sample_greedy(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Randomly selects a token from the candidates based on their probability distribution.
    //LLAMA_API llama_token llama_sampling_sample_dist(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
    //LLAMA_API llama_token llama_sampling_sample(
    //        struct llama_sampling * smpl,
    //        llama_token_data_array * candidates);

    ///// @details Accepts the sampled token into the sampling context.
    ///// - adds it to "prev" tokens
    ///// - updates the grammar state (if apply_grammar is true)
    //LLAMA_API void llama_sampling_accept(
    //        struct llama_sampling * smpl,
    //        llama_token token,
    //        bool apply_grammar);

    ///// @details Get the number of accepted tokens so far (max of n_prev)
    //LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);

    ///// @details Get the ith accepted token
    ///// @param ith [0, n_prev), ith == 0 is the last accepted token.
    ///// returns LLAMA_TOKEN_NULL if ith is out of bounds
    //LLAMA_API llama_token llama_sampling_prev(
    //        const struct llama_sampling * smpl,
    //        int32_t ith);

    ///// @details Get the last accepted token
    ///// Same as llama_sampling_prev(smpl, 0)
    ///// returns LLAMA_TOKEN_NULL if there are no accepted tokens
    //LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);

    //
    // Sampling v2 API
    //
    // - Constraints
    //   The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their

@@ -1203,7 +1073,7 @@ extern "C" {

    LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);

    // do not call if used with llama_sampler_add_constraint
    // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint)
    LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);

    LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);

@@ -1221,11 +1091,7 @@ extern "C" {

    LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);

    // TODO: should this take ownership so the user does not need to call llama_constraint_free
    // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler?
    //
    // seems better to take the ownership, otherwise the copying of the sampler will be more complicated
    // important: takes ownership of the constraint object and will free it in llama_sampler_free
    LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);

    LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
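The ownership rules documented in the header hunks above determine who calls which free. A short sketch, with illustrative constraint parameters:

    static llama_sampler * make_sampler(llama_model * model) {
        llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

        llama_constraint * top_k = llama_constraint_init_top_k(40, 1);
        llama_sampler_add_constraint(smpl, top_k); // the sampler now owns top_k

        // later, llama_sampler_free(smpl) also frees top_k;
        // calling llama_constraint_free(top_k) after adding it would be a double free
        return smpl;
    }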
@ -421,113 +421,8 @@ void llama_constraint_penalties_impl(
|
|||
candidates->sorted = false;
|
||||
}
|
||||
|
||||
llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
|
||||
llama_constraint_softmax_impl(candidates);
|
||||
|
||||
// Estimate s_hat using the most probable m tokens
|
||||
float s_hat = 0.0;
|
||||
float sum_ti_bi = 0.0;
|
||||
float sum_ti_sq = 0.0;
|
||||
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
||||
float t_i = logf(float(i + 2) / float(i + 1));
|
||||
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
||||
sum_ti_bi += t_i * b_i;
|
||||
sum_ti_sq += t_i * t_i;
|
||||
}
|
||||
s_hat = sum_ti_bi / sum_ti_sq;
|
||||
|
||||
    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);

    // Sample the next word X using top-k sampling
    llama_constraint_top_k_impl(candidates, int(k), 1);
    llama_token X = llama_sampler_sample_dist_impl(candidates, rng);

    // Compute error as the difference between observed surprise and target surprise value
    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return candidate.id == X;
    }));
    float observed_surprise = -log2f(candidates->data[X_idx].p);
    float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    mu = mu - eta * e;

    return X;
}

llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
    llama_constraint_softmax_impl(candidates);

    // Truncate the words with surprise values greater than mu
    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return -log2f(candidate.p) > mu;
    }));

    if (candidates->size == 0) {
        candidates->size = 1;
    }

    // Normalize the probabilities of the remaining words
    llama_constraint_softmax_impl(candidates);

    // Sample the next word X from the remaining words
    llama_token X = llama_sampler_sample_dist_impl(candidates, rng);

    // Compute error as the difference between observed surprise and target surprise value
    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return candidate.id == X;
    }));

    float observed_surprise = -log2f(candidates->data[X_idx].p);
    float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    mu = mu - eta * e;

    return X;
}

llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) {
    if (probs) {
        // if probs are needed, we apply softmax to get the probabilities
        llama_constraint_softmax_impl(candidates);

        // the candidates are sorted, so we can just return the first one
        return candidates->data[0].id;
    }

    // return the token with the highest logit
    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit < b.logit;
    });

    llama_token result = max_iter->id;

    return result;
}

llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
    llama_constraint_softmax_impl(candidates);

    std::vector<float> probs;
    probs.reserve(candidates->size);

    for (size_t i = 0; i < candidates->size; ++i) {
        probs.push_back(candidates->data[i].p);
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());

    const int idx = dist(rng);
    llama_token result = candidates->data[idx].id;

    return result;
}

//
// sampling v2
// sampling
//

// constraints

@ -1172,3 +1067,107 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) {
    return smpl.prev.size();
}

llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) {
    llama_constraint_softmax_impl(candidates);

    // Estimate s_hat using the most probable m tokens
    float s_hat = 0.0;
    float sum_ti_bi = 0.0;
    float sum_ti_sq = 0.0;
    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
        float t_i = logf(float(i + 2) / float(i + 1));
        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
        sum_ti_bi += t_i * b_i;
        sum_ti_sq += t_i * t_i;
    }
    s_hat = sum_ti_bi / sum_ti_sq;

    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);

    // Sample the next word X using top-k sampling
    llama_constraint_top_k_impl(candidates, int(k), 1);
    llama_token X = llama_sampler_sample_dist_impl(candidates, rng);

    // Compute error as the difference between observed surprise and target surprise value
    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return candidate.id == X;
    }));
    float observed_surprise = -log2f(candidates->data[X_idx].p);
    float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    mu = mu - eta * e;

    return X;
}

llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) {
    llama_constraint_softmax_impl(candidates);

    // Truncate the words with surprise values greater than mu
    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return -log2f(candidate.p) > mu;
    }));

    if (candidates->size == 0) {
        candidates->size = 1;
    }

    // Normalize the probabilities of the remaining words
    llama_constraint_softmax_impl(candidates);

    // Sample the next word X from the remaining words
    llama_token X = llama_sampler_sample_dist_impl(candidates, rng);

    // Compute error as the difference between observed surprise and target surprise value
    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return candidate.id == X;
    }));

    float observed_surprise = -log2f(candidates->data[X_idx].p);
    float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    mu = mu - eta * e;

    return X;
}

llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) {
    if (probs) {
        // if probs are needed, we apply softmax to get the probabilities
        llama_constraint_softmax_impl(candidates);

        // the candidates are sorted, so we can just return the first one
        return candidates->data[0].id;
    }

    // return the token with the highest logit
    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit < b.logit;
    });

    llama_token result = max_iter->id;

    return result;
}

llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) {
    llama_constraint_softmax_impl(candidates);

    std::vector<float> probs;
    probs.reserve(candidates->size);

    for (size_t i = 0; i < candidates->size; ++i) {
        probs.push_back(candidates->data[i].p);
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());

    const int idx = dist(rng);
    llama_token result = candidates->data[idx].id;

    return result;
}
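// --- Editorial sketch (not part of this commit) -----------------------------------
// A minimal example of how the Mirostat v2 helper added above can be driven,
// assuming it is compiled inside the llama.cpp tree with the internal sampling
// header visible (the header name below is an assumption; it is not shown in this
// hunk). The logits and hyperparameters are made up for illustration; callers keep
// mu alive across calls and initialize it to 2 * tau, as the doc comments describe.

#include "llama-sampling.h" // assumed internal header exposing the *_impl helpers

#include <random>
#include <vector>

static llama_token sample_one_mirostat_v2(const std::vector<float> & logits, std::mt19937 & rng, float & mu) {
    // build a candidates array from raw logits; p is filled in by the softmax
    // performed inside llama_sampler_sample_mirostat_v2_impl
    std::vector<llama_token_data> cur;
    cur.reserve(logits.size());
    for (llama_token id = 0; id < (llama_token) logits.size(); ++id) {
        cur.push_back(llama_token_data{ id, logits[id], 0.0f });
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    const float tau = 5.0f; // target surprise
    const float eta = 0.1f; // learning rate

    // truncates to tokens whose surprise is <= mu, samples one token from the rest,
    // then nudges mu toward the target surprise
    return llama_sampler_sample_mirostat_v2_impl(&cur_p, rng, tau, eta, mu);
}
// --- end editorial sketch ----------------------------------------------------------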
@ -10,6 +10,7 @@ struct llama_grammar;

using llama_token_cnt = std::unordered_map<llama_token, int>;

// TODO: tmp exposed, until tests start using llama_constraint
void llama_constraint_softmax_impl  (struct llama_token_data_array * candidates);
void llama_constraint_top_k_impl    (struct llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_constraint_top_p_impl    (struct llama_token_data_array * candidates, float p, size_t min_keep);

@ -27,30 +28,6 @@ void llama_constraint_penalties_impl(
        float penalty_freq,
        float penalty_present);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_impl   (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);

/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);

llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs);
llama_token llama_sampler_sample_dist_impl  (struct llama_token_data_array * candidates, std::mt19937 & rng);

//
// sampling v2
//

// constraints

struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep);

@ -128,3 +105,21 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d

llama_token llama_sampler_prev_impl  (const struct llama_sampler & smpl, int ith);
int         llama_sampler_n_prev_impl(const struct llama_sampler & smpl);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_impl   (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu);

/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);

llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs);
llama_token llama_sampler_sample_dist_impl  (struct llama_token_data_array * candidates, std::mt19937 & rng);
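// --- Editorial sketch (not part of this commit) -----------------------------------
// The last two declarations above are the terminal samplers: greedy picks the
// highest-logit token, dist softmaxes and draws from the resulting distribution.
// A rough illustration with toy values, assuming the same includes and header as
// the earlier sketch:

static void greedy_vs_dist_example() {
    std::vector<llama_token_data> cur = { { 0, 1.0f, 0.0f }, { 1, 2.0f, 0.0f }, { 2, 0.5f, 0.0f } }; // toy logits
    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    const llama_token t_greedy = llama_sampler_sample_greedy_impl(&cur_p, /*probs=*/false); // highest logit -> id 1

    std::mt19937 rng(42);
    const llama_token t_dist = llama_sampler_sample_dist_impl(&cur_p, rng); // softmax, then random draw

    (void) t_greedy; (void) t_dist;
}
// --- end editorial sketch ----------------------------------------------------------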
344
src/llama.cpp
@ -20609,346 +20609,6 @@ int32_t llama_chat_apply_template(
// sampling
//

//struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) {
//    return llama_sampling_init_impl(model->vocab, params);
//}

//void llama_sampling_free(struct llama_sampling * smpl) {
//    if (smpl == nullptr) {
//        return;
//    }

//    llama_sampling_free_impl(smpl);
//}

//struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) {
//    return llama_sampling_cp_impl(*smpl);
//}

//void llama_sampling_reset(struct llama_sampling * smpl) {
//    llama_sampling_reset_impl(*smpl);
//}

//void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
//    llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
//}

//void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) {
//    llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias);
//}

//void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) {
//    const int n_vocab = smpl->vocab.n_vocab;

//    smpl->cur.resize(n_vocab);

//    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
//        smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
//    }

//    for (const auto & lb : smpl->logit_bias) {
//        smpl->cur[lb.token].logit += lb.bias;
//    }

//    if (smpl->params.ignore_eos) {
//        smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY;
//    }

//    smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false };

//    // apply penalties
//    {
//        const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit;

//        llama_sampling_penalties(smpl, &smpl->cur_p);

//        if (!smpl->params.penalize_nl) {
//            for (size_t idx = 0; idx < smpl->cur_p.size; idx++) {
//                if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) {
//                    smpl->cur_p.data[idx].logit = nl_logit;
//                    break;
//                }
//            }
//        }
//    }
//}

//llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) {
//    return &smpl->cur_p;
//}

//void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    llama_sampling_softmax_impl(candidates);
//}

//void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
//}

//void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
//}

//void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
//}

//void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
//}

//void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
//}

//void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    if (smpl->params.dynatemp_range > 0) {
//        const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
//        const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);

//        llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent);
//    } else {
//        llama_sampling_temp_impl(candidates, smpl->params.temp);
//    }
//}

//void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_grammar_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    if (smpl->grammar) {
//        llama_sampling_grammar_impl(candidates, *smpl->grammar);

//        smpl->n_grammar++;
//    }
//}

//void llama_sampling_penalties(
//        struct llama_sampling * smpl,
//        llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());

//    const float penalty_repeat  = smpl->params.penalty_repeat;
//    const float penalty_freq    = smpl->params.penalty_freq;
//    const float penalty_present = smpl->params.penalty_present;

//    if ((penalty_last_n == 0) ||
//        (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
//        return;
//    }

//    // Create a frequency map to count occurrences of each token in last_tokens
//    // TODO: move to sampling state and avoid reallocation
//    llama_token_cnt token_count;
//    for (size_t i = 0; i < penalty_last_n; ++i) {
//        token_count[smpl->prev.rat(i)]++;
//    }

//    llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present);
//}

//llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    const auto type = smpl->params.mirostat;

//    llama_token res;

//    if (type == 1) {
//        res = llama_sampling_sample_mirostat_impl(candidates,
//                smpl->rng,
//                smpl->params.mirostat_tau,
//                smpl->params.mirostat_eta,
//                100,
//                smpl->vocab.n_vocab,
//                smpl->mirostat_mu);
//    } else if (type == 2) {
//        res = llama_sampling_sample_mirostat_v2_impl(candidates,
//                smpl->rng,
//                smpl->params.mirostat_tau,
//                smpl->params.mirostat_eta,
//                smpl->mirostat_mu);
//    } else {
//        GGML_ABORT("invalid mirostat type: %d", type);
//    }

//    smpl->n_sample++;

//    return res;
//}

//llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    auto res = llama_sampling_sample_greedy_impl(candidates);

//    smpl->n_sample++;

//    return res;
//}

//llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);

//    smpl->n_sample++;

//    return res;
//}

//llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
//    time_meas tm(smpl->t_sample_us);

//    if (candidates == nullptr) {
//        candidates = &smpl->cur_p;
//    }

//    const auto & params = smpl->params;

//    const float temp     = params.temp;
//    const int   mirostat = params.mirostat;

//    auto & cur_p = candidates;

//    llama_token res = 0;

//    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
//        // greedy sampling, with probs
//        llama_sampling_softmax_impl(cur_p);
//        res = cur_p->data[0].id;
//    } else if (temp == 0.0f) {
//        // greedy sampling, no probs
//        res = llama_sampling_sample_greedy(smpl, cur_p);
//    } else {
//        if (mirostat != 0) {
//            llama_sampling_temp(smpl, cur_p);
//            res = llama_sampling_sample_mirostat(smpl, cur_p);
//        } else {
//            for (const auto & sampler : smpl->samplers) {
//                switch (sampler) {
//                    case LLAMA_CONSTRAINT_TYPE_TOP_K:       llama_sampling_top_k_impl    (cur_p, smpl->params.top_k, smpl->params.min_keep); break;
//                    case LLAMA_CONSTRAINT_TYPE_TFS_Z:       llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break;
//                    case LLAMA_CONSTRAINT_TYPE_TYPICAL_P:   llama_sampling_typical_impl  (cur_p, smpl->params.typ_p, smpl->params.min_keep); break;
//                    case LLAMA_CONSTRAINT_TYPE_TOP_P:       llama_sampling_top_p_impl    (cur_p, smpl->params.top_p, smpl->params.min_keep); break;
//                    case LLAMA_CONSTRAINT_TYPE_MIN_P:       llama_sampling_min_p_impl    (cur_p, smpl->params.min_p, smpl->params.min_keep); break;
//                    case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl     (cur_p, temp); break;
//                    default : break;
//                }
//            }

//            res = llama_sampling_sample_dist(smpl, cur_p);

//            //{
//            //    const int n_top = 10;
//            //    LOG("top %d candidates:\n", n_top);

//            //    for (int i = 0; i < n_top; i++) {
//            //        const llama_token id = cur_p.data[i].id;
//            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
//            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
//            //    }
//            //}

//            //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str());
//        }
//    }

//    smpl->n_sample++;

//    return res;
//}

//void llama_sampling_accept(
//        struct llama_sampling * smpl,
//        llama_token token,
//        bool apply_grammar) {
//    time_meas tm(smpl->t_accept_us);

//    llama_sampling_accept_impl(*smpl, token, apply_grammar);

//    smpl->n_accept++;
//}

//int llama_sampling_n_prev(const struct llama_sampling * smpl) {
//    return llama_sampling_n_prev_impl(*smpl);
//}

//llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
//    return llama_sampling_prev_impl(*smpl, ith);
//}

//llama_token llama_sampling_last(const struct llama_sampling * smpl) {
//    return llama_sampling_prev_impl(*smpl, 0);
//}

//
// sampling v2
//
struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) {
    return llama_constraint_init_top_k_impl(k, min_keep);
}

@ -21070,6 +20730,10 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) {
    time_meas tm(smpl->t_sample_us);

    if (candidates == nullptr) {
        candidates = &smpl->cur_p;
    }

    llama_sampler_apply_impl(*smpl, candidates);
}
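// --- Editorial sketch (not part of this commit) -----------------------------------
// Calling convention established by the hunk above: passing nullptr makes
// llama_sampler_apply fall back to the sampler's internal cur_p candidates, so a
// caller that has already populated the candidates (setup not shown here) can do:
//
//     llama_sampler_apply(smpl, nullptr); // applies the configured constraints to smpl->cur_p
// --- end editorial sketch ----------------------------------------------------------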
@ -30,9 +30,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sampling_softmax_impl(&candidates_p);
    llama_constraint_softmax_impl(&candidates_p);
    DUMP(&candidates_p);
    llama_sampling_top_k_impl(&candidates_p, k, 1);
    llama_constraint_top_k_impl(&candidates_p, k, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@ -52,9 +52,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sampling_softmax_impl(&candidates_p);
    llama_constraint_softmax_impl(&candidates_p);
    DUMP(&candidates_p);
    llama_sampling_top_p_impl(&candidates_p, p, 1);
    llama_constraint_top_p_impl(&candidates_p, p, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@ -75,7 +75,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sampling_tail_free_impl(&candidates_p, z, 1);
    llama_constraint_tail_free_impl(&candidates_p, z, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@ -96,9 +96,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sampling_min_p_impl(&candidates_p, p, 1);
    llama_constraint_min_p_impl(&candidates_p, p, 1);
    DUMP(&candidates_p);
    llama_sampling_softmax_impl(&candidates_p);
    llama_constraint_softmax_impl(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {

@ -118,7 +118,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sampling_typical_impl(&candidates_p, p, 1);
    llama_constraint_typical_impl(&candidates_p, p, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@ -148,10 +148,10 @@ static void test_penalties(
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sampling_softmax_impl(&candidates_p);
    llama_constraint_softmax_impl(&candidates_p);
    DUMP(&candidates_p);
    llama_sampling_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence);
    llama_sampling_softmax_impl(&candidates_p);
    llama_constraint_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence);
    llama_constraint_softmax_impl(&candidates_p);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());

@ -176,16 +176,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler

    for (auto s : samplers_sequence) {
        switch (s){
            case 'k': llama_sampling_top_k_impl(&candidates_p, top_k, 1); break;
            case 'k': llama_constraint_top_k_impl(&candidates_p, top_k, 1); break;
            case 'f': GGML_ABORT("tail_free test not implemented");
            case 'y': GGML_ABORT("typical test not implemented");
            case 'p': llama_sampling_top_p_impl(&candidates_p, top_p, 1); break;
            case 'm': llama_sampling_min_p_impl(&candidates_p, min_p, 1); break;
            case 'p': llama_constraint_top_p_impl(&candidates_p, top_p, 1); break;
            case 'm': llama_constraint_min_p_impl(&candidates_p, min_p, 1); break;
            case 't': GGML_ABORT("temperature test not implemented");
            default : GGML_ABORT("Unknown sampler");
        }

        llama_sampling_softmax_impl(&candidates_p); // make sure tokens are sorted for tests
        llama_constraint_softmax_impl(&candidates_p); // make sure tokens are sorted for tests

        const int size = candidates_p.size;
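// --- Editorial sketch (not part of this commit) -----------------------------------
// The renamed llama_constraint_*_impl helpers touched in these test hunks can be
// chained directly on a llama_token_data_array, in the same order test_sampler_queue
// drives them; the parameter values here are made up for illustration:
//
//     llama_constraint_top_k_impl  (&candidates_p, /*k=*/40,    /*min_keep=*/1);
//     llama_constraint_top_p_impl  (&candidates_p, /*p=*/0.95f, /*min_keep=*/1);
//     llama_constraint_min_p_impl  (&candidates_p, /*p=*/0.05f, /*min_keep=*/1);
//     llama_constraint_softmax_impl(&candidates_p); // keep the array normalized and sorted between steps
// --- end editorial sketch ----------------------------------------------------------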