llama : use struct llama_sampling in the sampling API

ggml-ci
Georgi Gerganov 2024-07-23 17:35:28 +03:00
parent f866cb9342
commit dbf85440c7
GPG key ID: 449E073F9DC10735
30 changed files with 437 additions and 395 deletions

View file

@ -2125,7 +2125,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx); llama_kv_cache_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);
llama_reset_timings(lctx); llama_reset_timings(lctx, nullptr, nullptr);
} }
return std::make_tuple(model, lctx); return std::make_tuple(model, lctx);

View file

@ -2,12 +2,11 @@
#include <random> #include <random>
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) { struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
struct llama_sampling_context * result = new llama_sampling_context(); struct llama_sampling_context * result = new llama_sampling_context();
result->params = params; result->params = params;
result->seq_id = seq_id; result->smpl = smpl;
result->ctx = ctx;
result->grammar = nullptr; result->grammar = nullptr;
// if there is a grammar, parse it // if there is a grammar, parse it
@ -43,7 +42,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
result->n_valid = 0; result->n_valid = 0;
llama_sampling_set_rng_seed(result, params.seed); llama_sampling_set_rng_seed(result->smpl, params.seed);
return result; return result;
} }
@ -79,13 +78,6 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->n_valid = 0; ctx->n_valid = 0;
} }
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = std::random_device{}();
}
llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id);
}
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) { if (dst->grammar) {
llama_grammar_free(dst->grammar); llama_grammar_free(dst->grammar);
@ -230,10 +222,13 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
// no reasons to expose this function in header // no reasons to expose this function in header
static void sampler_queue( static void sampler_queue(
struct llama_context * ctx_main, struct llama_sampling_context * ctx_sampling,
const llama_sampling_params & params,
llama_token_data_array & cur_p, llama_token_data_array & cur_p,
size_t min_keep) { size_t min_keep) {
llama_sampling * smpl = ctx_sampling->smpl;
const llama_sampling_params & params = ctx_sampling->params;
const float temp = params.temp; const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range; const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent; const float dynatemp_exponent = params.dynatemp_exponent;
@ -246,18 +241,18 @@ static void sampler_queue(
for (auto sampler_type : samplers_sequence) { for (auto sampler_type : samplers_sequence) {
switch (sampler_type) { switch (sampler_type) {
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break;
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break;
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break;
case llama_sampler_type::TEMPERATURE: case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) { if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
} else { } else {
llama_sample_temp(ctx_main, &cur_p, temp); llama_sampling_temp(smpl, &cur_p, temp);
} }
break; break;
default : break; default : break;
@ -271,6 +266,8 @@ static llama_token llama_sampling_sample_impl(
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx, const int idx,
bool is_resampling) { bool is_resampling) {
llama_sampling * smpl = ctx_sampling->smpl;
const llama_sampling_params & params = ctx_sampling->params; const llama_sampling_params & params = ctx_sampling->params;
const float temp = params.temp; const float temp = params.temp;
@ -287,26 +284,26 @@ static llama_token llama_sampling_sample_impl(
if (temp < 0.0) { if (temp < 0.0) {
// greedy sampling, with probs // greedy sampling, with probs
llama_sample_softmax(ctx_main, &cur_p); llama_sampling_softmax(smpl, &cur_p);
id = cur_p.data[0].id; id = cur_p.data[0].id;
} else if (temp == 0.0) { } else if (temp == 0.0) {
// greedy sampling, no probs // greedy sampling, no probs
id = llama_sample_token_greedy(ctx_main, &cur_p); id = llama_sampling_sample_greedy(smpl, &cur_p);
} else { } else {
if (mirostat == 1) { if (mirostat == 1) {
const int mirostat_m = 100; const int mirostat_m = 100;
llama_sample_temp(ctx_main, &cur_p, temp); llama_sampling_temp(smpl, &cur_p, temp);
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
} else if (mirostat == 2) { } else if (mirostat == 2) {
llama_sample_temp(ctx_main, &cur_p, temp); llama_sampling_temp(smpl, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else { } else {
// temperature sampling // temperature sampling
size_t min_keep = std::max(1, params.min_keep); size_t min_keep = std::max(1, params.min_keep);
sampler_queue(ctx_main, params, cur_p, min_keep); sampler_queue(ctx_sampling, cur_p, min_keep);
id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id); id = llama_sampling_sample(smpl, &cur_p);
//{ //{
// const int n_top = 10; // const int n_top = 10;
@ -315,11 +312,11 @@ static llama_token llama_sampling_sample_impl(
// for (int i = 0; i < n_top; i++) { // for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id; // const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled. // (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
// } // }
//} //}
//LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str());
} }
} }
@ -360,6 +357,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
const int idx, const int idx,
bool apply_grammar, bool apply_grammar,
std::vector<float> * original_logits) { std::vector<float> * original_logits) {
llama_sampling * smpl = ctx_sampling->smpl;
const llama_sampling_params & params = ctx_sampling->params; const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@ -390,7 +389,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (ctx_cfg) { if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
} }
cur.resize(n_vocab); cur.resize(n_vocab);
@ -407,7 +406,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (penalty_tokens_used_size) { if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p, llama_sampling_repetition_penalties(smpl, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
@ -445,7 +444,7 @@ llama_token_data_array llama_sampling_prepare(
const int idx, const int idx,
bool apply_grammar, bool apply_grammar,
std::vector<float> * original_logits) { std::vector<float> * original_logits) {
return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
} }
void llama_sampling_accept( void llama_sampling_accept(
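
A minimal sketch condensing the reworked sampler_queue above: the chain is now driven entirely through the llama_sampling handle stored in the common sampling context. The helper name is hypothetical, the llama_sampling_params field names are assumed from the surrounding common code, and the dynamic-temperature branch is omitted.

    // hypothetical helper, not part of this patch: apply the configured
    // sampler sequence to a candidate array via the new llama_sampling_* API
    static void apply_sampler_chain(llama_sampling_context * ctx_sampling, llama_token_data_array & cur_p, size_t min_keep) {
        llama_sampling              * smpl   = ctx_sampling->smpl;
        const llama_sampling_params & params = ctx_sampling->params;

        for (const auto sampler : params.samplers_sequence) {
            switch (sampler) {
                case llama_sampler_type::TOP_K:       llama_sampling_top_k    (smpl, &cur_p, params.top_k,     min_keep); break;
                case llama_sampler_type::TFS_Z:       llama_sampling_tail_free(smpl, &cur_p, params.tfs_z,     min_keep); break;
                case llama_sampler_type::TYPICAL_P:   llama_sampling_typical  (smpl, &cur_p, params.typical_p, min_keep); break;
                case llama_sampler_type::TOP_P:       llama_sampling_top_p    (smpl, &cur_p, params.top_p,     min_keep); break;
                case llama_sampler_type::MIN_P:       llama_sampling_min_p    (smpl, &cur_p, params.min_p,     min_keep); break;
                case llama_sampler_type::TEMPERATURE: llama_sampling_temp     (smpl, &cur_p, params.temp);                break;
                default: break;
            }
        }
    }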

View file

@ -70,12 +70,10 @@ struct llama_sampling_context {
// parameters that will be used for sampling // parameters that will be used for sampling
llama_sampling_params params; llama_sampling_params params;
llama_seq_id seq_id;
// mirostat sampler state // mirostat sampler state
float mirostat_mu; float mirostat_mu;
llama_context * ctx; // TMP llama_sampling * smpl;
llama_grammar * grammar; llama_grammar * grammar;
// internal // internal
@ -91,7 +89,7 @@ struct llama_sampling_context {
#include "common.h" #include "common.h"
// Create a new sampling context instance. // Create a new sampling context instance.
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id); struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl);
void llama_sampling_free(struct llama_sampling_context * ctx); void llama_sampling_free(struct llama_sampling_context * ctx);
@ -100,9 +98,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
// - reset grammar // - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx); void llama_sampling_reset(llama_sampling_context * ctx);
// Set the sampler seed
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
// Copy the sampler context // Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
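
The header change above moves ownership of the low-level sampler to the caller. A short sketch of the two initialization patterns used by the examples in this commit, assuming sparams, ctx and model are in scope:

    // reuse the sampler owned by the llama_context; it is released together
    // with the context, so only the common wrapper is freed afterwards
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));

    // or allocate a dedicated sampler (e.g. one per client or draft sequence);
    // the caller must then free it explicitly before freeing the wrapper
    llama_sampling_context * ctx_sampling_own = llama_sampling_init(sparams, llama_sampling_init(llama_n_vocab(model)));
    // ...
    llama_sampling_free(ctx_sampling_own->smpl);
    llama_sampling_free(ctx_sampling_own);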

View file

@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
} }
} }
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
llama_batch_free(batch); llama_batch_free(batch);

View file

@ -64,6 +64,7 @@ int main(int argc, char ** argv) {
ctx_params.n_batch = std::max(n_predict, n_parallel); ctx_params.n_batch = std::max(n_predict, n_parallel);
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_sampling * smpl = llama_get_sampling(ctx);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@ -180,13 +181,13 @@ int main(int argc, char ** argv) {
const float top_p = 0.9f; const float top_p = 0.9f;
const float temp = 0.4f; const float temp = 0.4f;
llama_sample_top_k(ctx, &candidates_p, top_k, 1); llama_sampling_top_k(smpl, &candidates_p, top_k, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1); llama_sampling_top_p(smpl, &candidates_p, top_p, 1);
llama_sample_temp (ctx, &candidates_p, temp); llama_sampling_temp (smpl, &candidates_p, temp);
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@ -244,12 +245,13 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
llama_print_timings(ctx); llama_print_timings(ctx, smpl, nullptr);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
llama_batch_free(batch); llama_batch_free(batch);
llama_sampling_free(smpl);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
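
A condensed sketch of the example's new per-token sampling step, assuming smpl was obtained with llama_get_sampling(ctx) as above; the helper name is hypothetical, the parameter values are illustrative, and <vector> plus llama.h are required.

    static llama_token sample_next(llama_context * ctx, llama_sampling * smpl, int idx) {
        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
        float * logits = llama_get_logits_ith(ctx, idx);

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
            candidates.push_back(llama_token_data{ token_id, logits[token_id], 0.0f });
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // same chain as in the hunk above, now routed through llama_sampling
        llama_sampling_top_k(smpl, &candidates_p, 40,   1);
        llama_sampling_top_p(smpl, &candidates_p, 0.9f, 1);
        llama_sampling_temp (smpl, &candidates_p, 0.4f);

        return llama_sampling_sample(smpl, &candidates_p);
    }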

View file

@ -258,7 +258,7 @@ int main(int argc, char ** argv) {
} }
// clean up // clean up
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
llama_batch_free(batch); llama_batch_free(batch);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);

View file

@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);

View file

@ -9,7 +9,7 @@
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) { static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
std::vector<std::vector<float>> result; std::vector<std::vector<float>> result;
const llama_model * mdl = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const std::string input_string = instruction + sentences[i]; const std::string input_string = instruction + sentences[i];
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false); std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);
const int32_t n_toks = inputs.size(); const int32_t n_toks = inputs.size();
// GritLM seems to have EOS = "" // GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(mdl)); // inputs.push_back(llama_token_eos(model));
// we want to ignore instruction tokens for mean pooling // we want to ignore instruction tokens for mean pooling
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size(); const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();
#ifdef GRIT_DEBUG #ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample // debug tokens - should be matching as referenced in the GritLM sample
@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_decode(ctx, batch); llama_decode(ctx, batch);
// get embedding dimensions // get embedding dimensions
uint64_t n_embd = llama_n_embd(mdl); uint64_t n_embd = llama_n_embd(model);
// allocate embedding output // allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f); std::vector<float> emb_unorm(n_embd, 0.0f);
@ -95,8 +95,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
std::string result; std::string result;
const llama_model * mdl = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl); llama_sampling * smpl = llama_get_sampling(ctx);
llama_token eos_token = llama_token_eos(model);
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, false); llama_set_embeddings(ctx, false);
@ -104,7 +105,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true); std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
int32_t i_current_token = 0; int32_t i_current_token = 0;
while (true) { while (true) {
@ -118,14 +119,14 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_decode(ctx, bat); llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl)); auto candidates = std::vector<llama_token_data>(llama_n_vocab(model));
auto n_candidates = (int32_t)candidates.size(); auto n_candidates = (int32_t)candidates.size();
for (int32_t token = 0; token < n_candidates; token++) { for (int32_t token = 0; token < n_candidates; token++) {
candidates[token] = llama_token_data{ token, logits[token], 0.0f }; candidates[token] = llama_token_data{ token, logits[token], 0.0f };
} }
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
llama_token token = llama_sample_token_greedy(ctx, &candidates_p); llama_token token = llama_sampling_sample_greedy(smpl, &candidates_p);
if (token == eos_token) { if (token == eos_token) {
break; break;
} }
@ -167,10 +168,10 @@ int main(int argc, char * argv[]) {
llama_backend_init(); llama_backend_init();
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
// create generation context // create generation context
llama_context * ctx = llama_new_context_with_model(mdl, cparams); llama_context * ctx = llama_new_context_with_model(model, cparams);
// ### Embedding/Representation ### // ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic // samples taken from: https://github.com/ContextualAI/gritlm#basic
@ -191,7 +192,7 @@ int main(int argc, char * argv[]) {
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
const int n_embd = llama_n_embd(mdl); const int n_embd = llama_n_embd(model);
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
@ -212,7 +213,7 @@ int main(int argc, char * argv[]) {
} }
llama_free(ctx); llama_free(ctx);
llama_free_model(mdl); llama_free_model(model);
llama_backend_free(); llama_backend_free();
return 0; return 0;

View file

@ -638,7 +638,7 @@ int main(int argc, char ** argv) {
g_collector.save_imatrix(); g_collector.save_imatrix();
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);

View file

@ -34,6 +34,7 @@
static llama_context ** g_ctx; static llama_context ** g_ctx;
static llama_model ** g_model; static llama_model ** g_model;
static llama_sampling_context ** g_ctx_sampling;
static gpt_params * g_params; static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens; static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss; static std::ostringstream * g_output_ss;
@ -93,7 +94,7 @@ static void sigint_handler(int signo) {
} else { } else {
console::cleanup(); console::cleanup();
printf("\n"); printf("\n");
llama_print_timings(*g_ctx); llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl, (*g_ctx_sampling)->grammar);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130); _exit(130);
} }
@ -171,11 +172,13 @@ int main(int argc, char ** argv) {
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);
llama_model * model; llama_model * model = nullptr;
llama_context * ctx; llama_context * ctx = nullptr;
llama_sampling_context * ctx_sampling = nullptr;
g_model = &model; g_model = &model;
g_ctx = &ctx; g_ctx = &ctx;
g_ctx_sampling = &ctx_sampling;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__); LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@ -346,7 +349,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd; std::vector<llama_token> embd;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0); ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
while (n_remain != 0 || params.interactive) { while (n_remain != 0 || params.interactive) {
// predict // predict
@ -635,7 +638,7 @@ int main(int argc, char ** argv) {
fflush(stdout); fflush(stdout);
} }
llama_print_timings(ctx); llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
llama_free(ctx); llama_free(ctx);

View file

@ -1434,7 +1434,7 @@ int main(int argc, char ** argv) {
fflush(p_err->fout); fflush(p_err->fout);
} }
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
llama_free(ctx); llama_free(ctx);
} }

View file

@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
LOG_TEE("\n"); LOG_TEE("\n");
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, llama_get_sampling(ctx_llava->ctx_llama));
if (!ctx_sampling) { if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1); exit(1);
@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
// process the prompt // process the prompt
process_prompt(ctx_llava, image_embed, &params, params.prompt); process_prompt(ctx_llava, image_embed, &params, params.prompt);
llama_print_timings(ctx_llava->ctx_llama); llama_print_timings(ctx_llava->ctx_llama, nullptr, nullptr);
llava_image_embed_free(image_embed); llava_image_embed_free(image_embed);
ctx_llava->model = NULL; ctx_llava->model = NULL;
llava_free(ctx_llava); llava_free(ctx_llava);
@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
// process the prompt // process the prompt
process_prompt(ctx_llava, image_embed, &params, params.prompt); process_prompt(ctx_llava, image_embed, &params, params.prompt);
llama_print_timings(ctx_llava->ctx_llama); llama_print_timings(ctx_llava->ctx_llama, nullptr, nullptr);
llava_image_embed_free(image_embed); llava_image_embed_free(image_embed);
ctx_llava->model = NULL; ctx_llava->model = NULL;
llava_free(ctx_llava); llava_free(ctx_llava);

View file

@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
// target model sampling context // target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
// verification n-grams // verification n-grams
std::vector<ngram_data> ngrams_cur(G); std::vector<ngram_data> ngrams_cur(G);
@ -468,7 +468,7 @@ int main(int argc, char ** argv) {
LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("n_accept = %d\n", n_accept);
llama_print_timings(ctx); llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar);
llama_kv_cache_view_free(&kvc_view); llama_kv_cache_view_free(&kvc_view);
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);

View file

@ -106,7 +106,7 @@ int main(int argc, char ** argv){
bool has_eos = false; bool has_eos = false;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
std::vector<llama_token> draft; std::vector<llama_token> draft;
@ -241,7 +241,7 @@ int main(int argc, char ** argv){
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ntarget:\n"); LOG_TEE("\ntarget:\n");
llama_print_timings(ctx); llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar);
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);
llama_batch_free(batch_tgt); llama_batch_free(batch_tgt);

View file

@ -33,6 +33,7 @@
static llama_context ** g_ctx; static llama_context ** g_ctx;
static llama_model ** g_model; static llama_model ** g_model;
static llama_sampling_context ** g_ctx_sampling;
static gpt_params * g_params; static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens; static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss; static std::ostringstream * g_output_ss;
@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
} else { } else {
console::cleanup(); console::cleanup();
printf("\n"); printf("\n");
llama_print_timings(*g_ctx); llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl, (*g_ctx_sampling)->grammar);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130); _exit(130);
} }
@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) { static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
llama_chat_msg new_msg{role, content}; llama_chat_msg new_msg{role, content};
auto formatted = llama_chat_format_single( auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content}); chat_msgs.push_back({role, content});
return formatted; return formatted;
} }
@ -197,12 +197,16 @@ int main(int argc, char ** argv) {
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);
llama_model * model; llama_model * model = nullptr;
llama_context * ctx; llama_context * ctx = nullptr;
llama_context * ctx_guidance = NULL; llama_context * ctx_guidance = nullptr;
llama_sampling_context * ctx_sampling = nullptr;
std::vector<llama_chat_msg> chat_msgs; std::vector<llama_chat_msg> chat_msgs;
g_model = &model; g_model = &model;
g_ctx = &ctx; g_ctx = &ctx;
g_ctx_sampling = &ctx_sampling;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__); LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@ -527,7 +531,7 @@ int main(int argc, char ** argv) {
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
} }
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0); ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
if (!ctx_sampling) { if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1); exit(1);
@ -975,7 +979,7 @@ int main(int argc, char ** argv) {
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
} }
llama_print_timings(ctx); llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
if (ctx_guidance) { llama_free(ctx_guidance); } if (ctx_guidance) { llama_free(ctx_guidance); }

View file

@ -51,6 +51,7 @@ static std::vector<std::string> k_prompts = {
struct client { struct client {
~client() { ~client() {
if (ctx_sampling) { if (ctx_sampling) {
llama_sampling_free(ctx_sampling->smpl);
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);
} }
} }
@ -161,7 +162,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) { for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i]; auto & client = clients[i];
client.id = i; client.id = i;
client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i); client.ctx_sampling = llama_sampling_init(params.sparams, llama_sampling_init(llama_n_vocab(model)));
} }
std::vector<llama_token> tokens_system; std::vector<llama_token> tokens_system;
@ -371,7 +372,7 @@ int main(int argc, char ** argv) {
} }
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();
@ -413,7 +414,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
llama_print_timings(ctx); // TODO: print sampling/grammar timings for all clients
llama_print_timings(ctx, nullptr, nullptr);
llama_batch_free(batch); llama_batch_free(batch);
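
Sketch of the per-client sampler lifecycle implied by the hunks above: every client allocates a standalone llama_sampling sized to the model vocabulary, and the destructor frees it before the common wrapper. Field and variable names are taken from the example.

    // setup: one standalone sampler per client
    client.ctx_sampling = llama_sampling_init(params.sparams, llama_sampling_init(llama_n_vocab(model)));

    // teardown (see ~client() above): low-level sampler first, then the wrapper
    llama_sampling_free(client.ctx_sampling->smpl);
    llama_sampling_free(client.ctx_sampling);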

View file

@ -80,12 +80,13 @@ int main(int argc, char ** argv) {
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1; return 1;
} }
llama_sampling * smpl = llama_get_sampling(ctx);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> tokens_list; std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true); tokens_list = ::llama_tokenize(ctx, params.prompt, true);
@ -230,7 +231,7 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// sample the most likely token // sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
@ -267,7 +268,7 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
fprintf(stderr, "\n"); fprintf(stderr, "\n");

View file

@ -2054,7 +2054,7 @@ int main(int argc, char ** argv) {
results = perplexity(ctx, params, n_ctx); results = perplexity(ctx, params, n_ctx);
} }
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
write_logfile(ctx, params, model, results); write_logfile(ctx, params, model, results);
llama_free(ctx); llama_free(ctx);

View file

@ -292,7 +292,7 @@ int main(int argc, char ** argv) {
} }
// clean up // clean up
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
llama_backend_free(); llama_backend_free();

View file

@ -37,6 +37,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampling * smpl = llama_get_sampling(ctx);
// tokenize prompt // tokenize prompt
auto tokens = llama_tokenize(ctx, params.prompt, true); auto tokens = llama_tokenize(ctx, params.prompt, true);
@ -72,7 +74,7 @@ int main(int argc, char ** argv) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx, &candidates_p); auto next_token = llama_sampling_sample(smpl, &candidates_p);
auto next_token_str = llama_token_to_piece(ctx, next_token); auto next_token_str = llama_token_to_piece(ctx, next_token);
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
@ -95,6 +97,8 @@ int main(int argc, char ** argv) {
// make new context // make new context
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl2 = llama_get_sampling(ctx2);
printf("\nsecond run: %s", params.prompt.c_str()); printf("\nsecond run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file // load state (rng, logits, embedding and kv_cache) from file
@ -128,7 +132,7 @@ int main(int argc, char ** argv) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx2, &candidates_p); auto next_token = llama_sampling_sample(smpl2, &candidates_p);
auto next_token_str = llama_token_to_piece(ctx2, next_token); auto next_token_str = llama_token_to_piece(ctx2, next_token);
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
@ -153,7 +157,9 @@ int main(int argc, char ** argv) {
} }
// make new context // make new context
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
llama_sampling * smpl3 = llama_get_sampling(ctx3);
printf("\nsingle seq run: %s", params.prompt.c_str()); printf("\nsingle seq run: %s", params.prompt.c_str());
@ -216,7 +222,7 @@ int main(int argc, char ** argv) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx3, &candidates_p); auto next_token = llama_sampling_sample(smpl3, &candidates_p);
auto next_token_str = llama_token_to_piece(ctx3, next_token); auto next_token_str = llama_token_to_piece(ctx3, next_token);
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
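
Since the sampler now lives with its llama_context, each context created in this example fetches its own handle. A sketch of that pattern; reseeding is shown only to illustrate llama_sampling_set_rng_seed and assumes params.sparams.seed holds the seed used for the first run.

    llama_context  * ctx2  = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
    llama_sampling * smpl2 = llama_get_sampling(ctx2);

    // the RNG is per-sampler now; reseed if the second run should reproduce the first
    llama_sampling_set_rng_seed(smpl2, params.sparams.seed);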

View file

@ -664,6 +664,7 @@ struct server_context {
// Clear any sampling context // Clear any sampling context
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.ctx_sampling != nullptr) { if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling->smpl);
llama_sampling_free(slot.ctx_sampling); llama_sampling_free(slot.ctx_sampling);
} }
} }
@ -1088,9 +1089,11 @@ struct server_context {
{ {
if (slot.ctx_sampling != nullptr) { if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling->smpl);
llama_sampling_free(slot.ctx_sampling); llama_sampling_free(slot.ctx_sampling);
} }
slot.ctx_sampling = llama_sampling_init(slot.sparams, ctx, slot.id);
slot.ctx_sampling = llama_sampling_init(slot.sparams, llama_sampling_init(llama_n_vocab(model)));
if (slot.ctx_sampling == nullptr) { if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar // for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@ -2402,7 +2405,7 @@ struct server_context {
// Make sure at least n_probs top tokens are at the front of the vector: // Make sure at least n_probs top tokens are at the front of the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_valid) { if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
llama_sample_top_k(ctx, &cur_p, n_probs, 0); llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
} }
if (slot.sparams.temp == 0.0f) { if (slot.sparams.temp == 0.0f) {

View file

@ -55,6 +55,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
llama_sampling * smpl = llama_get_sampling(ctx);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> tokens_list; std::vector<llama_token> tokens_list;
@ -123,7 +125,7 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// sample the most likely token // sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@ -160,7 +162,7 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
llama_print_timings(ctx); llama_print_timings(ctx, nullptr, nullptr);
fprintf(stderr, "\n"); fprintf(stderr, "\n");

View file

@ -174,8 +174,8 @@ int main(int argc, char ** argv) {
// used to determine end of generation // used to determine end of generation
bool has_eos = false; bool has_eos = false;
// target model sampling context // target model sampling context (reuse the llama_context's sampling instance)
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx_tgt, 0); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx_tgt));
// draft sequence data // draft sequence data
std::vector<seq_draft> drafts(n_seq_dft); std::vector<seq_draft> drafts(n_seq_dft);
@ -186,7 +186,8 @@ int main(int argc, char ** argv) {
} }
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams, ctx_dft, s); // allocate llama_sampling for each draft sequence
drafts[s].ctx_sampling = llama_sampling_init(params.sparams, llama_sampling_init(llama_n_vocab(model_dft)));
} }
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@ -230,8 +231,10 @@ int main(int argc, char ** argv) {
// stochastic verification // stochastic verification
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
llama_sample_softmax(ctx_tgt, &dist_tgt); llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt);
float p_tgt = 0, p_dft = 0;
float p_tgt = 0.0f;
float p_dft = 0.0f;
// GGML_ASSERT(dist_tgt.size() == dist_dft.size()); // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
@ -327,7 +330,7 @@ int main(int argc, char ** argv) {
// all drafted tokens were rejected // all drafted tokens were rejected
// sample from the target model // sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n"); LOG("all drafted tokens were rejected, sampling from residual distribution\n");
token_id = llama_sample_token(ctx_tgt, &dist_tgt); token_id = llama_sampling_sample(ctx_sampling->smpl, &dist_tgt);
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id); token_str = llama_token_to_piece(ctx_tgt, token_id);
} }
@ -589,13 +592,15 @@ int main(int argc, char ** argv) {
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
LOG_TEE("\ndraft:\n"); LOG_TEE("\ndraft:\n");
llama_print_timings(ctx_dft); // TODO: print sampling/grammar timings for all drafts
llama_print_timings(ctx_dft, nullptr, nullptr);
LOG_TEE("\ntarget:\n"); LOG_TEE("\ntarget:\n");
llama_print_timings(ctx_tgt); llama_print_timings(ctx_tgt, ctx_sampling->smpl, ctx_sampling->grammar);
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
llama_sampling_free(drafts[s].ctx_sampling->smpl);
llama_sampling_free(drafts[s].ctx_sampling); llama_sampling_free(drafts[s].ctx_sampling);
} }
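
Ownership in the speculative example after this change, gathered from the hunks above: the target context reuses its own sampler, each draft sequence allocates one, and only the allocated samplers are freed explicitly.

    // target: reuse the sampler owned by ctx_tgt
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx_tgt));

    // drafts: one standalone sampler per draft sequence
    for (int s = 0; s < n_seq_dft; ++s) {
        drafts[s].ctx_sampling = llama_sampling_init(params.sparams, llama_sampling_init(llama_n_vocab(model_dft)));
    }

    // teardown mirrors the allocation
    for (int s = 0; s < n_seq_dft; ++s) {
        llama_sampling_free(drafts[s].ctx_sampling->smpl);
        llama_sampling_free(drafts[s].ctx_sampling);
    }
    llama_sampling_free(ctx_sampling);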

View file

@ -40,7 +40,7 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 8 #define LLAMA_SESSION_VERSION 7
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 1 #define LLAMA_STATE_SEQ_VERSION 1
@ -394,16 +394,22 @@ extern "C" {
uint32_t value; // Unicode code point or rule ID uint32_t value; // Unicode code point or rule ID
} llama_grammar_element; } llama_grammar_element;
// sampling types
struct llama_sampling;
// performance timing information // performance timing information
struct llama_timings { struct llama_timings {
double t_start_ms; double t_start_ms;
double t_end_ms; double t_end_ms;
double t_load_ms; double t_load_ms;
double t_sample_ms; double t_sampling_ms;
double t_grammar_ms;
double t_p_eval_ms; double t_p_eval_ms;
double t_eval_ms; double t_eval_ms;
int32_t n_sample; int32_t n_sampling;
int32_t n_grammar_sample;
int32_t n_grammar_accept;
int32_t n_p_eval; int32_t n_p_eval;
int32_t n_eval; int32_t n_eval;
}; };
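
A short sketch of reading the renamed and added fields, assuming a valid llama_context * ctx; whether llama_get_timings() fills the sampling/grammar counters from the context's own sampler is not shown in this hunk, so they may read as zero.

    const llama_timings timings = llama_get_timings(ctx);

    printf("sampling: %8.2f ms over %5d samples\n", timings.t_sampling_ms, timings.n_sampling);
    printf("grammar:  %8.2f ms, %5d sampled / %5d accepted\n",
            timings.t_grammar_ms, timings.n_grammar_sample, timings.n_grammar_accept);
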
@ -454,7 +460,8 @@ extern "C" {
LLAMA_API bool llama_supports_mlock (void); LLAMA_API bool llama_supports_mlock (void);
LLAMA_API bool llama_supports_gpu_offload(void); LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API struct llama_sampling * llama_get_sampling( struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
@ -1028,85 +1035,87 @@ extern "C" {
// Sampling functions // Sampling functions
// //
// Sets the current rng seed. // TODO: args become llama_sampling_params
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); LLAMA_API struct llama_sampling * llama_sampling_init(int32_t n_vocab);
LLAMA_API DEPRECATED(void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id), LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
"temporary API, until llama_sampling_context is implemented, do not use");
// Sets the current rng seed.
LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_repetition_penalties( LLAMA_API void llama_sampling_repetition_penalties(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
const llama_token * last_tokens, const llama_token * last_tokens,
size_t penalty_last_n, size_t penalty_last_n,
float penalty_repeat, float penalty_repeat,
float penalty_freq, float penalty_freq,
float penalty_present); float penalty_present);
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param logits Logits extracted from the original generation context. /// @param logits Logits extracted from the original generation context.
/// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
LLAMA_API void llama_sample_apply_guidance( LLAMA_API void llama_sampling_apply_guidance(
struct llama_context * ctx, struct llama_sampling * smpl,
float * logits, float * logits,
float * logits_guidance, float * logits_guidance,
float scale); float scale);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax( LLAMA_API void llama_sampling_softmax(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates); llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_k( LLAMA_API void llama_sampling_top_k(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
int32_t k, int32_t k,
size_t min_keep); size_t min_keep);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_p( LLAMA_API void llama_sampling_top_p(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float p, float p,
size_t min_keep); size_t min_keep);
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API void llama_sample_min_p( LLAMA_API void llama_sampling_min_p(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float p, float p,
size_t min_keep); size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free( LLAMA_API void llama_sampling_tail_free(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float z, float z,
size_t min_keep); size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical( LLAMA_API void llama_sampling_typical(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float p, float p,
size_t min_keep); size_t min_keep);
/// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
LLAMA_API void llama_sample_entropy( LLAMA_API void llama_sampling_entropy(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates_p, llama_token_data_array * candidates_p,
float min_temp, float min_temp,
float max_temp, float max_temp,
float exponent_val); float exponent_val);
LLAMA_API void llama_sample_temp( LLAMA_API void llama_sampling_temp(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float temp); float temp);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@ -1114,43 +1123,36 @@ extern "C" {
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat( LLAMA_API llama_token llama_sampling_sample_mirostat(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float tau, float tau,
float eta, float eta,
int32_t m, int32_t m,
float * mu); float * mu);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
LLAMA_API llama_token llama_sample_token_mirostat_v2( LLAMA_API llama_token llama_sampling_sample_mirostat_v2(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
float tau, float tau,
float eta, float eta,
float * mu); float * mu);
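For context (a summary of the paper, not part of the diff): writing the surprise of a candidate x as -log2 p(x), the Mirostat 2.0 step implemented by llama_sampling_sample_mirostat_v2 truncates, renormalizes, samples a token X, and then updates mu:

\[
\text{keep candidates } x \text{ with } -\log_2 p(x) \le \mu, \qquad
\mu \;\leftarrow\; \mu - \eta\,\bigl(-\log_2 p(X) - \tau\bigr)
\]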
/// @details Selects the token with the highest probability. /// @details Selects the token with the highest probability.
/// Does not compute the token probabilities. Use llama_sample_softmax() instead. /// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
LLAMA_API llama_token llama_sample_token_greedy( LLAMA_API llama_token llama_sampling_sample_greedy(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates); llama_token_data_array * candidates);
/// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of ctx. /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of smpl.
LLAMA_API llama_token llama_sample_token( LLAMA_API llama_token llama_sampling_sample(
struct llama_context * ctx, struct llama_sampling * smpl,
llama_token_data_array * candidates); llama_token_data_array * candidates);
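A hedged sketch of how these two entry points are typically driven (not from this diff; `ctx`, `smpl` and `n_vocab` are assumed to be set up by the caller, and the logits are taken from a single-token batch): build a llama_token_data_array from the current logits, then pick either the greedy or the RNG-based sampler.

// Hedged sketch: candidate setup shared by llama_sampling_sample_greedy and llama_sampling_sample.
// Requires <vector> and llama.h.
static llama_token sample_next(struct llama_context * ctx, struct llama_sampling * smpl, int32_t n_vocab, bool greedy) {
    const float * logits = llama_get_logits(ctx); // logits of the last decoded token (single-token batch assumed)

    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        cur.push_back(llama_token_data{ id, logits[id], 0.0f });
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    return greedy
        ? llama_sampling_sample_greedy(smpl, &cur_p)  // argmax over logits, probabilities not computed
        : llama_sampling_sample       (smpl, &cur_p); // softmax + multinomial draw using smpl's RNG
}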
/// @details Same as llama_sample_token, but uses a sequence-specific RNG[seq_id].
LLAMA_API DEPRECATED(llama_token llama_sample_token_seq(
struct llama_context * ctx,
llama_token_data_array * candidates,
llama_seq_id seq_id),
"temporary API, until llama_sampling_context is implemented, do not use");
// //
// Model split // Model split
@ -1169,8 +1171,8 @@ extern "C" {
// Performance information // Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx); LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar);
LLAMA_API void llama_reset_timings(struct llama_context * ctx); LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar);
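Callers now pass the sampling and grammar objects explicitly, as in the llama_init_from_gpt_params hunk earlier in this diff. A short usage sketch (names assumed by the caller): passing NULL for grammar omits the grammar statistics, and passing NULL for smpl falls back to the context's internal sampling state.

// Print and reset timings for a context that samples but uses no grammar.
llama_print_timings(ctx, smpl, /*grammar=*/ NULL);
llama_reset_timings(ctx, smpl, /*grammar=*/ NULL);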
// Print system information // Print system information
LLAMA_API const char * llama_print_system_info(void); LLAMA_API const char * llama_print_system_info(void);

View file

@ -438,7 +438,7 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
} }
void llama_grammar_free_impl(struct llama_grammar * grammar) { void llama_grammar_free_impl(struct llama_grammar * grammar) {
@ -446,7 +446,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
} }
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) { struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) {
llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8 }; llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 };
// redirect elements in stacks to point to new rules // redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t is = 0; is < result->stacks.size(); is++) {

View file

@ -11,6 +11,11 @@ struct llama_grammar {
// buffer for partially generated UTF-8 sequence from accepted tokens // buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8; llama_partial_utf8 partial_utf8;
mutable int64_t t_total_us;
mutable int32_t n_sample;
mutable int32_t n_accept;
}; };
struct llama_grammar * llama_get_grammar(struct llama_context * ctx); struct llama_grammar * llama_get_grammar(struct llama_context * ctx);

View file

@ -21,7 +21,15 @@ static void llama_log_softmax(float * array, size_t size) {
} }
} }
void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { struct llama_sampling * llama_sampling_init_impl(int32_t n_vocab) {
return new llama_sampling(n_vocab);
}
void llama_sampling_free_impl(struct llama_sampling * sampling) {
delete sampling;
}
void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) { if (seed == LLAMA_DEFAULT_SEED) {
seed = time(NULL); seed = time(NULL);
} }
@ -29,7 +37,7 @@ void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) {
smpl.rng.seed(seed); smpl.rng.seed(seed);
} }
void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
GGML_ASSERT(candidates->size > 0); GGML_ASSERT(candidates->size > 0);
// Sort the logits in descending order // Sort the logits in descending order
@ -54,7 +62,7 @@ void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_dat
} }
} }
void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) { void llama_sampling_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
// if (k >= (int32_t)candidates->size) { // if (k >= (int32_t)candidates->size) {
// return; // return;
@ -129,12 +137,12 @@ void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_
candidates->size = k; candidates->size = k;
} }
void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sampling_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p >= 1.0f) { if (p >= 1.0f) {
return; return;
} }
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
// Compute the cumulative probabilities // Compute the cumulative probabilities
float cum_sum = 0.0f; float cum_sum = 0.0f;
@ -155,7 +163,7 @@ void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_arra
candidates->size = last_idx; candidates->size = last_idx;
} }
void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sampling_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p <= 0.0f || !candidates->size) { if (p <= 0.0f || !candidates->size) {
return; return;
} }
@ -210,12 +218,12 @@ void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_
} }
} }
void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) { void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) { if (z >= 1.0f || candidates->size <= 2) {
return; return;
} }
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
// Compute the first and second derivatives // Compute the first and second derivatives
std::vector<float> first_derivatives(candidates->size - 1); std::vector<float> first_derivatives(candidates->size - 1);
@ -264,7 +272,7 @@ void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_
candidates->size = last_idx; candidates->size = last_idx;
} }
void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
// Reference implementation: // Reference implementation:
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
if (p >= 1.0f) { if (p >= 1.0f) {
@ -272,7 +280,7 @@ void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_ar
} }
// Compute the softmax of logits and calculate entropy // Compute the softmax of logits and calculate entropy
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
float entropy = 0.0f; float entropy = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
@ -322,7 +330,7 @@ void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_ar
candidates->sorted = false; candidates->sorted = false;
} }
void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
// no need to do anything if there is only one (or zero) candidates // no need to do anything if there is only one (or zero) candidates
if(candidates->size <= 1) { if(candidates->size <= 1) {
return; return;
@ -331,7 +339,7 @@ void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_ar
// Calculate maximum possible entropy // Calculate maximum possible entropy
float max_entropy = -logf(1.0f / candidates->size); float max_entropy = -logf(1.0f / candidates->size);
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
// Calculate entropy of the softmax probabilities // Calculate entropy of the softmax probabilities
float entropy = 0.0f; float entropy = 0.0f;
@ -383,13 +391,13 @@ void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_ar
#endif #endif
} }
void llama_sample_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) { void llama_sampling_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) {
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
candidates->data[i].logit /= temp; candidates->data[i].logit /= temp;
} }
} }
void llama_sample_repetition_penalties_impl( void llama_sampling_repetition_penalties_impl(
struct llama_sampling & /*smpl*/, struct llama_sampling & /*smpl*/,
llama_token_data_array * candidates, llama_token_data_array * candidates,
const llama_token * last_tokens, const llama_token * last_tokens,
@ -430,7 +438,7 @@ void llama_sample_repetition_penalties_impl(
candidates->sorted = false; candidates->sorted = false;
} }
void llama_sample_apply_guidance_impl( void llama_sampling_apply_guidance_impl(
struct llama_sampling & smpl, struct llama_sampling & smpl,
float * logits, float * logits,
float * logits_guidance, float * logits_guidance,
@ -448,10 +456,10 @@ void llama_sample_apply_guidance_impl(
} }
} }
llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
const int32_t n_vocab = float(smpl.n_vocab); const int32_t n_vocab = float(smpl.n_vocab);
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
// Estimate s_hat using the most probable m tokens // Estimate s_hat using the most probable m tokens
float s_hat = 0.0; float s_hat = 0.0;
@ -470,8 +478,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
// Sample the next word X using top-k sampling // Sample the next word X using top-k sampling
llama_sample_top_k_impl(smpl, candidates, int(k), 1); llama_sampling_top_k_impl(smpl, candidates, int(k), 1);
llama_token X = llama_sample_token_impl(smpl, candidates); llama_token X = llama_sampling_sample_impl(smpl, candidates);
// Compute error as the difference between observed surprise and target surprise value // Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@ -486,8 +494,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama
return X; return X;
} }
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
// Truncate the words with surprise values greater than mu // Truncate the words with surprise values greater than mu
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@ -499,10 +507,10 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, ll
} }
// Normalize the probabilities of the remaining words // Normalize the probabilities of the remaining words
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
// Sample the next word X from the remaining words // Sample the next word X from the remaining words
llama_token X = llama_sample_token_impl(smpl, candidates); llama_token X = llama_sampling_sample_impl(smpl, candidates);
// Compute error as the difference between observed surprise and target surprise value // Compute error as the difference between observed surprise and target surprise value
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@ -517,7 +525,7 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, ll
return X; return X;
} }
llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { llama_token llama_sampling_sample_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) {
// Find max element // Find max element
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit < b.logit; return a.logit < b.logit;
@ -528,8 +536,8 @@ llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, lla
return result; return result;
} }
llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) { llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
llama_sample_softmax_impl(smpl, candidates); llama_sampling_softmax_impl(smpl, candidates);
std::vector<float> probs; std::vector<float> probs;
probs.reserve(candidates->size); probs.reserve(candidates->size);
@ -545,6 +553,6 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama
return result; return result;
} }
llama_token llama_sample_token_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { llama_token llama_sampling_sample_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) {
return llama_sample_token_with_rng_impl(smpl, candidates, smpl.rng); return llama_sampling_sample_with_rng_impl(smpl, candidates, smpl.rng);
} }
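The middle of llama_sampling_sample_with_rng_impl is elided by this hunk and unchanged by the diff; the draw it performs is the standard std::discrete_distribution pattern over the softmax-ed candidate probabilities, roughly as in this self-contained sketch:

// Hedged sketch of the multinomial draw (not the verbatim implementation).
#include <random>
#include <vector>

static int draw_index(const std::vector<float> & probs, std::mt19937 & rng) {
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(rng); // index into the softmax-ed candidates array
}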

View file

@ -3,29 +3,37 @@
#include "llama-impl.h" #include "llama-impl.h"
struct llama_sampling { struct llama_sampling {
llama_sampling(uint32_t seed, int32_t n_vocab) : rng(seed), n_vocab(n_vocab) {} llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
const int32_t n_vocab;
std::mt19937 rng; std::mt19937 rng;
const int32_t n_vocab; mutable int64_t t_total_us = 0;
mutable int32_t n_sample = 0;
}; };
// //
// internal API // internal API
// //
void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed); struct llama_sampling * llama_sampling_init_impl(int32_t n_vocab);
void llama_sample_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); void llama_sampling_free_impl(struct llama_sampling * sampling);
void llama_sample_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sample_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep);
void llama_sample_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sample_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp);
void llama_sample_repetition_penalties_impl( void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed);
void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sampling_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sampling_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep);
void llama_sampling_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sampling_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sampling_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp);
void llama_sampling_repetition_penalties_impl(
struct llama_sampling & smpl, struct llama_sampling & smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
const llama_token * last_tokens, const llama_token * last_tokens,
@ -34,15 +42,15 @@ void llama_sample_repetition_penalties_impl(
float penalty_freq, float penalty_freq,
float penalty_present); float penalty_present);
void llama_sample_apply_guidance_impl( void llama_sampling_apply_guidance_impl(
struct llama_sampling & smpl, struct llama_sampling & smpl,
float * logits, float * logits,
float * logits_guidance, float * logits_guidance,
float scale); float scale);
llama_token llama_sample_token_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); llama_token llama_sampling_sample_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
llama_token llama_sample_token_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); llama_token llama_sampling_sample_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);
llama_token llama_sample_token_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng); llama_token llama_sampling_sample_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); llama_token llama_sampling_sample_impl (struct llama_sampling & smpl, llama_token_data_array * candidates);

View file

@ -2669,6 +2669,7 @@ struct llama_model {
struct llama_context { struct llama_context {
llama_context(const llama_model & model) llama_context(const llama_model & model)
: model(model) : model(model)
, sampling(llama_n_vocab(&model))
, grammar() , grammar()
, t_start_us(model.t_start_us) , t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {} , t_load_us(model.t_load_us) {}
@ -2686,12 +2687,11 @@ struct llama_context {
const struct llama_model & model; const struct llama_model & model;
struct llama_cparams cparams; struct llama_cparams cparams;
struct llama_sampling sampling;
struct llama_grammar grammar; struct llama_grammar grammar;
struct llama_kv_cache kv_self; struct llama_kv_cache kv_self;
struct llama_control_vector cvec; struct llama_control_vector cvec;
std::vector<struct llama_sampling> sampling; // sampling context for each sequence
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters; std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
std::vector<ggml_backend_t> backends; std::vector<ggml_backend_t> backends;
@ -2707,14 +2707,12 @@ struct llama_context {
mutable int64_t t_start_us; mutable int64_t t_start_us;
mutable int64_t t_load_us; mutable int64_t t_load_us;
mutable int64_t t_sample_us = 0;
mutable int64_t t_p_eval_us = 0; mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0; mutable int64_t t_eval_us = 0;
mutable int64_t t_compute_start_us = 0; mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0; mutable int64_t n_queued_tokens = 0;
mutable int32_t n_sample = 0;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls mutable int32_t n_eval = 0; // number of eval calls
@ -16542,10 +16540,7 @@ struct llama_context * llama_new_context_with_model(
ctx->abort_callback = params.abort_callback; ctx->abort_callback = params.abort_callback;
ctx->abort_callback_data = params.abort_callback_data; ctx->abort_callback_data = params.abort_callback_data;
ctx->sampling.reserve(cparams.n_seq_max); llama_sampling_set_rng_seed_impl(ctx->sampling, params.seed);
for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
ctx->sampling.emplace_back(params.seed, llama_n_vocab(model));
}
ctx->logits_all = params.logits_all; ctx->logits_all = params.logits_all;
@ -16827,6 +16822,10 @@ const struct llama_model * llama_get_model(const struct llama_context * ctx) {
return &ctx->model; return &ctx->model;
} }
struct llama_sampling * llama_get_sampling(struct llama_context * ctx) {
return &ctx->sampling;
}
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) { const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
return &ctx->model.vocab; return &ctx->model.vocab;
} }
@ -17322,7 +17321,6 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes. // for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_n_rng = sizeof(uint32_t);
const size_t s_rng_size = sizeof(size_t); const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = LLAMA_MAX_RNG_STATE; const size_t s_rng = LLAMA_MAX_RNG_STATE;
const size_t s_n_outputs = sizeof(size_t); const size_t s_n_outputs = sizeof(size_t);
@ -17342,8 +17340,8 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
const size_t s_total = ( const size_t s_total = (
+ s_n_rng + s_rng_size
+ cparams.n_seq_max*(s_rng_size + s_rng) + s_rng
+ s_n_outputs + s_n_outputs
+ s_output_pos + s_output_pos
+ s_logits_size + s_logits_size
@ -17360,7 +17358,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
); );
// on session change it is very likely that the state size has changed - so we need to update this function // on session change it is very likely that the state size has changed - so we need to update this function
static_assert(LLAMA_SESSION_VERSION == 8, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
return s_total; return s_total;
} }
@ -17423,22 +17421,16 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
// copy rngs // copy rngs
{ {
const uint32_t n_rng = ctx->sampling.size(); std::ostringstream rng_ss;
rng_ss << ctx->sampling.rng;
data_ctx->write(&n_rng, sizeof(n_rng)); const std::string & rng_str = rng_ss.str();
const size_t rng_size = rng_str.size();
for (const auto & smpl : ctx->sampling) { GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
std::ostringstream rng_ss;
rng_ss << smpl.rng;
const std::string & rng_str = rng_ss.str(); data_ctx->write(&rng_size, sizeof(rng_size));
const size_t rng_size = rng_str.size(); data_ctx->write(rng_str.data(), rng_size);
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
data_ctx->write(&rng_size, sizeof(rng_size));
data_ctx->write(rng_str.data(), rng_size);
}
} }
// copy outputs // copy outputs
@ -17588,24 +17580,17 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
// set rngs // set rngs
{ {
uint32_t n_rng; size_t rng_size;
memcpy(&n_rng, inp, sizeof(n_rng)); inp += sizeof(n_rng); memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
GGML_ASSERT(n_rng == ctx->cparams.n_seq_max); GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
for (auto & smpl : ctx->sampling) { std::string rng_str((const char *)inp, rng_size); inp += rng_size;
size_t rng_size;
memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); std::istringstream rng_ss(rng_str);
rng_ss >> ctx->sampling.rng;
std::string rng_str((const char *)inp, rng_size); inp += rng_size; GGML_ASSERT(!rng_ss.fail());
std::istringstream rng_ss(rng_str);
rng_ss >> smpl.rng;
GGML_ASSERT(!rng_ss.fail());
}
} }
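The single-RNG round trip above relies on std::mt19937's stream operators; a self-contained sketch of the same technique, independent of the llama.cpp types (LLAMA_MAX_RNG_STATE only bounds the serialized size):

// Serialize and restore a std::mt19937, as the state save/load code above does.
#include <cassert>
#include <random>
#include <sstream>
#include <string>

int main() {
    std::mt19937 rng(1337);

    std::ostringstream oss;
    oss << rng;                       // writes the engine state as text
    const std::string state = oss.str();

    std::mt19937 restored;
    std::istringstream iss(state);
    iss >> restored;                  // restores the exact same state
    assert(!iss.fail());

    assert(rng() == restored());      // both engines now produce identical values
    return 0;
}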
// set output ids // set output ids
@ -18978,11 +18963,14 @@ void llama_grammar_sample(
const struct llama_grammar * grammar, const struct llama_grammar * grammar,
const struct llama_context * ctx, const struct llama_context * ctx,
llama_token_data_array * candidates) { llama_token_data_array * candidates) {
time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling time_meas tm(grammar->t_total_us); // TODO: measure grammar time separately from sampling
llama_grammar_sample_impl(*grammar, ctx->model.vocab, candidates); llama_grammar_sample_impl(*grammar, ctx->model.vocab, candidates);
grammar->n_sample++;
} }
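time_meas is an internal llama.cpp helper; a minimal sketch of what such an accumulating RAII timer looks like (an assumption about its shape, not the actual implementation) may help read the hunks that follow:

// Hypothetical sketch of a scope timer like the internal time_meas:
// on destruction it adds the elapsed microseconds to the referenced counter.
struct time_meas_sketch {
    explicit time_meas_sketch(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
    ~time_meas_sketch() { t_acc += ggml_time_us() - t_start_us; }

    const int64_t t_start_us;
    int64_t     & t_acc;
};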
// deprecated
void llama_sample_grammar( void llama_sample_grammar(
struct llama_context * ctx, struct llama_context * ctx,
llama_token_data_array * candidates, llama_token_data_array * candidates,
@ -18994,140 +18982,148 @@ void llama_grammar_accept_token(
struct llama_grammar * grammar, struct llama_grammar * grammar,
struct llama_context * ctx, struct llama_context * ctx,
llama_token token) { llama_token token) {
time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling time_meas tm(grammar->t_total_us); // TODO: measure grammar time separately from sampling
llama_grammar_accept_token_impl(*grammar, ctx->model.vocab, token); llama_grammar_accept_token_impl(*grammar, ctx->model.vocab, token);
grammar->n_accept++;
} }
// //
// sampling // sampling
// //
void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { struct llama_sampling * llama_sampling_init(int32_t n_vocab) {
llama_set_rng_seed_impl(ctx->sampling[0], seed); return llama_sampling_init_impl(n_vocab);
} }
void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id seq_id) { void llama_sampling_free(struct llama_sampling * smpl) {
llama_set_rng_seed_impl(ctx->sampling[seq_id], seed); if (smpl == nullptr) {
return;
}
llama_sampling_free_impl(smpl);
} }
void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
time_meas tm(ctx->t_sample_us); llama_sampling_set_rng_seed_impl(*smpl, seed);
llama_sample_softmax_impl(ctx->sampling[0], candidates);
} }
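The resulting lifecycle for standalone use of the new object is short (a hedged sketch; `n_vocab` would normally come from llama_n_vocab(model)):

// Create, seed and free a standalone llama_sampling object.
struct llama_sampling * smpl = llama_sampling_init(n_vocab); // n_vocab assumed known
llama_sampling_set_rng_seed(smpl, 1234);                     // LLAMA_DEFAULT_SEED picks a time-based seed
// ... use the llama_sampling_* calls below ...
llama_sampling_free(smpl);                                   // safe to call with NULL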
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_top_k_impl(ctx->sampling[0], candidates, k, min_keep); llama_sampling_softmax_impl(*smpl, candidates);
} }
void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_top_p_impl(ctx->sampling[0], candidates, p, min_keep); llama_sampling_top_k_impl(*smpl, candidates, k, min_keep);
} }
void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_min_p_impl(ctx->sampling[0], candidates, p, min_keep); llama_sampling_top_p_impl(*smpl, candidates, p, min_keep);
} }
void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_tail_free_impl(ctx->sampling[0], candidates, z, min_keep); llama_sampling_min_p_impl(*smpl, candidates, p, min_keep);
} }
void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_typical_impl(ctx->sampling[0], candidates, p, min_keep); llama_sampling_tail_free_impl(*smpl, candidates, z, min_keep);
} }
void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_entropy_impl(ctx->sampling[0], candidates_p, min_temp, max_temp, exponent_val); llama_sampling_typical_impl(*smpl, candidates, p, min_keep);
} }
void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { void llama_sampling_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_temp_impl(ctx->sampling[0], candidates_p, temp); llama_sampling_entropy_impl(*smpl, candidates_p, min_temp, max_temp, exponent_val);
} }
void llama_sample_repetition_penalties( void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float temp) {
struct llama_context * ctx, time_meas tm(smpl->t_total_us);
llama_sampling_temp_impl(*smpl, candidates_p, temp);
}
void llama_sampling_repetition_penalties(
struct llama_sampling * smpl,
llama_token_data_array * candidates, llama_token_data_array * candidates,
const llama_token * last_tokens, const llama_token * last_tokens,
size_t penalty_last_n, size_t penalty_last_n,
float penalty_repeat, float penalty_repeat,
float penalty_freq, float penalty_freq,
float penalty_present) { float penalty_present) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_repetition_penalties_impl(ctx->sampling[0], candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
} }
void llama_sample_apply_guidance( void llama_sampling_apply_guidance(
struct llama_context * ctx, struct llama_sampling * smpl,
float * logits, float * logits,
float * logits_guidance, float * logits_guidance,
float scale) { float scale) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
llama_sample_apply_guidance_impl(ctx->sampling[0], logits, logits_guidance, scale); llama_sampling_apply_guidance_impl(*smpl, logits, logits_guidance, scale);
} }
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
auto res = llama_sample_token_mirostat_impl(ctx->sampling[0], candidates, tau, eta, m, mu); auto res = llama_sampling_sample_mirostat_impl(*smpl, candidates, tau, eta, m, mu);
ctx->n_sample++; smpl->n_sample++;
return res; return res;
} }
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { llama_token llama_sampling_sample_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
auto res = llama_sample_token_mirostat_v2_impl(ctx->sampling[0], candidates, tau, eta, mu); auto res = llama_sampling_sample_mirostat_v2_impl(*smpl, candidates, tau, eta, mu);
ctx->n_sample++; smpl->n_sample++;
return res; return res;
} }
llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
time_meas tm(ctx->t_sample_us); time_meas tm(smpl->t_total_us);
auto res = llama_sample_token_greedy_impl(ctx->sampling[0], candidates); auto res = llama_sampling_sample_greedy_impl(*smpl, candidates);
ctx->n_sample++; smpl->n_sample++;
return res; return res;
} }
llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
return llama_sample_token_seq(ctx, candidates, 0); time_meas tm(smpl->t_total_us);
}
llama_token llama_sample_token_seq(struct llama_context * ctx, llama_token_data_array * candidates, llama_seq_id seq_id) { auto res = llama_sampling_sample_impl(*smpl, candidates);
GGML_ASSERT(seq_id >= 0 && seq_id < (int32_t) ctx->cparams.n_seq_max);
time_meas tm(ctx->t_sample_us); smpl->n_sample++;
auto res = llama_sample_token_impl(ctx->sampling[seq_id], candidates);
ctx->n_sample++;
return res; return res;
} }
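Putting the renamed calls together, a typical truncation pipeline ending in a draw might look like the sketch below (hedged; the parameter values, the prepared `cur_p` array and the `prev` vector of recent tokens are assumptions, and common/sampling.cpp wires this up with many more options):

// Hedged sketch: a common ordering of the renamed samplers on a prepared cur_p.
llama_sampling_repetition_penalties(smpl, &cur_p, prev.data(), prev.size(),
                                    /*penalty_repeat =*/ 1.1f,
                                    /*penalty_freq   =*/ 0.0f,
                                    /*penalty_present=*/ 0.0f);
llama_sampling_top_k(smpl, &cur_p, /*k=*/ 40,    /*min_keep=*/ 1);
llama_sampling_top_p(smpl, &cur_p, /*p=*/ 0.95f, /*min_keep=*/ 1);
llama_sampling_temp (smpl, &cur_p, /*temp=*/ 0.8f);
llama_token tok = llama_sampling_sample(smpl, &cur_p);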
//
// model split
//
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
@ -19152,30 +19148,29 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
return 0; return 0;
} }
struct llama_timings llama_get_timings(struct llama_context * ctx) { void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar) {
struct llama_timings result = { const llama_timings timings = {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
/*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_end_ms =*/ 1.00 * ggml_time_ms(),
/*.t_load_ms =*/ 1e-3 * ctx->t_load_us, /*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
/*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_total_us : ctx->sampling.t_total_us),
/*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_grammar_ms =*/ grammar ? (1e-3 * grammar->t_total_us) : 0.0,
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
/*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
/*.n_sample =*/ std::max(1, ctx->n_sample), /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : ctx->sampling.n_sample),
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_grammar_sample =*/ std::max(0, grammar ? grammar->n_sample : 0),
/*.n_eval =*/ std::max(1, ctx->n_eval), /*.n_grammar_accept =*/ std::max(0, grammar ? grammar->n_accept : 0),
/*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
/*.n_eval =*/ std::max(1, ctx->n_eval),
}; };
return result;
}
void llama_print_timings(struct llama_context * ctx) {
const llama_timings timings = llama_get_timings(ctx);
LLAMA_LOG_INFO("\n"); LLAMA_LOG_INFO("\n");
LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms);
LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_grammar_ms, timings.n_grammar_sample, timings.t_grammar_ms / timings.n_grammar_sample, 1e3 / timings.t_grammar_ms * timings.n_grammar_sample);
LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
@ -19183,11 +19178,18 @@ void llama_print_timings(struct llama_context * ctx) {
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
} }
void llama_reset_timings(struct llama_context * ctx) { void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar) {
ctx->t_start_us = ggml_time_us(); ctx->t_start_us = ggml_time_us();
ctx->t_eval_us = ctx->n_eval = 0; ctx->t_eval_us = ctx->n_eval = 0;
ctx->t_sample_us = ctx->n_sample = 0;
ctx->t_p_eval_us = ctx->n_p_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0;
if (smpl) {
smpl->t_total_us = smpl->n_sample = 0;
}
if (grammar) {
grammar->t_total_us = grammar->n_sample = grammar->n_accept = 0;
}
} }
const char * llama_print_system_info(void) { const char * llama_print_system_info(void) {
@ -19233,21 +19235,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
1.0e-3 * ctx->t_eval_us / ctx->n_eval); 1.0e-3 * ctx->t_eval_us / ctx->n_eval);
fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
1.0e-3 * ctx->t_sample_us / ctx->n_sample);
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample);
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us);
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
1.0e6 * ctx->n_eval / ctx->t_eval_us); 1.0e6 * ctx->n_eval / ctx->t_eval_us);
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
1.0e6 * ctx->n_sample / ctx->t_sample_us);
} }
// For internal test use // For internal test use

View file

@ -1,5 +1,5 @@
#include "ggml.h" #include "ggml.h"
#include "llama-sampling.h" #include "llama.h"
#ifdef NDEBUG #ifdef NDEBUG
#undef NDEBUG #undef NDEBUG
@ -20,7 +20,7 @@ static void dump(const llama_token_data_array * candidates) {
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) { static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -30,9 +30,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax_impl(smpl, &candidates_p); llama_sampling_softmax(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_top_k_impl(smpl, &candidates_p, k, 1); llama_sampling_top_k(smpl, &candidates_p, k, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -43,7 +43,7 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -53,9 +53,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax_impl(smpl, &candidates_p); llama_sampling_softmax(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_top_p_impl(smpl, &candidates_p, p, 1); llama_sampling_top_p(smpl, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -66,7 +66,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) { static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -77,7 +77,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_tail_free_impl(smpl, &candidates_p, z, 1); llama_sampling_tail_free(smpl, &candidates_p, z, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -88,7 +88,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -99,9 +99,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_min_p_impl(smpl, &candidates_p, p, 1); llama_sampling_min_p(smpl, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_softmax_impl(smpl, &candidates_p); llama_sampling_softmax(smpl, &candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
@ -111,7 +111,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -122,7 +122,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_typical_impl(smpl, &candidates_p, p, 1); llama_sampling_typical(smpl, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -138,7 +138,7 @@ static void test_repetition_penalties(
GGML_ASSERT(probs.size() == expected_probs.size()); GGML_ASSERT(probs.size() == expected_probs.size());
const size_t n_vocab = probs.size(); const size_t n_vocab = probs.size();
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -148,10 +148,10 @@ static void test_repetition_penalties(
} }
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax_impl(smpl, &candidates_p); llama_sampling_softmax(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); llama_sampling_repetition_penalties(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
llama_sample_softmax_impl(smpl, &candidates_p); llama_sampling_softmax(smpl, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
GGML_ASSERT(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
@ -162,7 +162,7 @@ static void test_repetition_penalties(
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
) { ) {
llama_sampling smpl(1234, n_vocab); llama_sampling * smpl = llama_sampling_init(n_vocab);
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -178,16 +178,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
for (auto s : samplers_sequence) { for (auto s : samplers_sequence) {
switch (s){ switch (s){
case 'k': llama_sample_top_k_impl(smpl, &candidates_p, top_k, 1); break; case 'k': llama_sampling_top_k(smpl, &candidates_p, top_k, 1); break;
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break; case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
case 'y': GGML_ASSERT(false && "typical test not implemented"); break; case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
case 'p': llama_sample_top_p_impl(smpl, &candidates_p, top_p, 1); break; case 'p': llama_sampling_top_p(smpl, &candidates_p, top_p, 1); break;
case 'm': llama_sample_min_p_impl(smpl, &candidates_p, min_p, 1); break; case 'm': llama_sampling_min_p(smpl, &candidates_p, min_p, 1); break;
case 't': GGML_ASSERT(false && "temperature test not implemented"); break; case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
default : GGML_ASSERT(false && "Unknown sampler"); break; default : GGML_ASSERT(false && "Unknown sampler"); break;
} }
llama_sample_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests llama_sampling_softmax(smpl, &candidates_p); // make sure tokens are sorted for tests
const int size = candidates_p.size; const int size = candidates_p.size;