diff --git a/common/common.cpp b/common/common.cpp index dbb724fbb..0ba948840 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2125,7 +2125,7 @@ std::tuple llama_init_from_gpt_par llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); - llama_reset_timings(lctx); + llama_reset_timings(lctx, nullptr, nullptr); } return std::make_tuple(model, lctx); diff --git a/common/sampling.cpp b/common/sampling.cpp index f1f41daab..cd4accf8f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,12 +2,11 @@ #include -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) { +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) { struct llama_sampling_context * result = new llama_sampling_context(); result->params = params; - result->seq_id = seq_id; - result->ctx = ctx; + result->smpl = smpl; result->grammar = nullptr; // if there is a grammar, parse it @@ -43,7 +42,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->n_valid = 0; - llama_sampling_set_rng_seed(result, params.seed); + llama_sampling_set_rng_seed(result->smpl, params.seed); return result; } @@ -79,13 +78,6 @@ void llama_sampling_reset(llama_sampling_context * ctx) { ctx->n_valid = 0; } -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = std::random_device{}(); - } - llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id); -} - void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { if (dst->grammar) { llama_grammar_free(dst->grammar); @@ -230,10 +222,13 @@ std::vector llama_sampling_types_from_chars(const std::strin // no reasons to expose this function in header static void sampler_queue( - struct llama_context * ctx_main, - const llama_sampling_params & params, + struct llama_sampling_context * ctx_sampling, llama_token_data_array & cur_p, size_t min_keep) { + llama_sampling * smpl = ctx_sampling->smpl; + + const llama_sampling_params & params = ctx_sampling->params; + const float temp = params.temp; const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; @@ -246,18 +241,18 @@ static void sampler_queue( for (auto sampler_type : samplers_sequence) { switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break; + case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break; + case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break; + case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break; + case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break; case llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range); - llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); + llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); } else { - llama_sample_temp(ctx_main, &cur_p, temp); + llama_sampling_temp(smpl, &cur_p, temp); } break; default : break; @@ -271,6 +266,8 @@ static llama_token llama_sampling_sample_impl( struct llama_context * ctx_cfg, const int idx, bool is_resampling) { + llama_sampling * smpl = ctx_sampling->smpl; + const llama_sampling_params & params = ctx_sampling->params; const float temp = params.temp; @@ -287,26 +284,26 @@ static llama_token llama_sampling_sample_impl( if (temp < 0.0) { // greedy sampling, with probs - llama_sample_softmax(ctx_main, &cur_p); + llama_sampling_softmax(smpl, &cur_p); id = cur_p.data[0].id; } else if (temp == 0.0) { // greedy sampling, no probs - id = llama_sample_token_greedy(ctx_main, &cur_p); + id = llama_sampling_sample_greedy(smpl, &cur_p); } else { if (mirostat == 1) { const int mirostat_m = 100; - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); + llama_sampling_temp(smpl, &cur_p, temp); + id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); } else if (mirostat == 2) { - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); + llama_sampling_temp(smpl, &cur_p, temp); + id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { // temperature sampling size_t min_keep = std::max(1, params.min_keep); - sampler_queue(ctx_main, params, cur_p, min_keep); + sampler_queue(ctx_sampling, cur_p, min_keep); - id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id); + id = llama_sampling_sample(smpl, &cur_p); //{ // const int n_top = 10; @@ -315,11 +312,11 @@ static llama_token llama_sampling_sample_impl( // for (int i = 0; i < n_top; i++) { // const llama_token id = cur_p.data[i].id; // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); // } //} - //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); + //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str()); } } @@ -360,6 +357,8 @@ static llama_token_data_array llama_sampling_prepare_impl( const int idx, bool apply_grammar, std::vector * original_logits) { + llama_sampling * smpl = ctx_sampling->smpl; + const llama_sampling_params & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); @@ -390,7 +389,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (ctx_cfg) { float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); - llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); + llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale); } cur.resize(n_vocab); @@ -407,7 +406,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - llama_sample_repetition_penalties(ctx_main, &cur_p, + llama_sampling_repetition_penalties(smpl, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); @@ -445,7 +444,7 @@ llama_token_data_array llama_sampling_prepare( const int idx, bool apply_grammar, std::vector * original_logits) { - return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); + return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits); } void llama_sampling_accept( diff --git a/common/sampling.h b/common/sampling.h index f03eb3be3..5f0c14923 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -70,12 +70,10 @@ struct llama_sampling_context { // parameters that will be used for sampling llama_sampling_params params; - llama_seq_id seq_id; - // mirostat sampler state float mirostat_mu; - llama_context * ctx; // TMP + llama_sampling * smpl; llama_grammar * grammar; // internal @@ -91,7 +89,7 @@ struct llama_sampling_context { #include "common.h" // Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id); +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl); void llama_sampling_free(struct llama_sampling_context * ctx); @@ -100,9 +98,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx); // - reset grammar void llama_sampling_reset(llama_sampling_context * ctx); -// Set the sampler seed -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); - // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 718f0a61a..9df2506c7 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { } } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); llama_batch_free(batch); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 53fbfb0a8..fb71f578c 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,6 +64,7 @@ int main(int argc, char ** argv) { ctx_params.n_batch = std::max(n_predict, n_parallel); llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_sampling * smpl = llama_get_sampling(ctx); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -180,13 +181,13 @@ int main(int argc, char ** argv) { const float top_p = 0.9f; const float temp = 0.4f; - llama_sample_top_k(ctx, &candidates_p, top_k, 1); - llama_sample_top_p(ctx, &candidates_p, top_p, 1); - llama_sample_temp (ctx, &candidates_p, temp); + llama_sampling_top_k(smpl, &candidates_p, top_k, 1); + llama_sampling_top_p(smpl, &candidates_p, top_p, 1); + llama_sampling_temp (smpl, &candidates_p, temp); - const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p); - //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -244,12 +245,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl, nullptr); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 1466e5b2b..0fdee8fd0 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -258,7 +258,7 @@ int main(int argc, char ** argv) { } // clean up - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); llama_batch_free(batch); llama_free(ctx); llama_free_model(model); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index c8a3016a4..e440d21c7 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -182,7 +182,7 @@ int main(int argc, char ** argv) { return 1; } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2c61c2e1e..98acb29d2 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -9,7 +9,7 @@ static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { std::vector> result; - const llama_model * mdl = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -18,16 +18,16 @@ static std::vector> encode(llama_context * ctx, const std::ve const std::string input_string = instruction + sentences[i]; - std::vector inputs = llama_tokenize(mdl, input_string, true, false); + std::vector inputs = llama_tokenize(model, input_string, true, false); const int32_t n_toks = inputs.size(); // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(mdl)); + // inputs.push_back(llama_token_eos(model)); // we want to ignore instruction tokens for mean pooling - const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size(); + const int32_t n_inst = llama_tokenize(model, instruction, true, false).size(); #ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample @@ -51,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(mdl); + uint64_t n_embd = llama_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -95,8 +95,9 @@ static std::vector> encode(llama_context * ctx, const std::ve static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { std::string result; - const llama_model * mdl = llama_get_model(ctx); - llama_token eos_token = llama_token_eos(mdl); + const llama_model * model = llama_get_model(ctx); + llama_sampling * smpl = llama_get_sampling(ctx); + llama_token eos_token = llama_token_eos(model); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -104,7 +105,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); - std::vector inputs = llama_tokenize(mdl, prompt, false, true); + std::vector inputs = llama_tokenize(model, prompt, false, true); int32_t i_current_token = 0; while (true) { @@ -118,14 +119,14 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_decode(ctx, bat); auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - auto candidates = std::vector(llama_n_vocab(mdl)); + auto candidates = std::vector(llama_n_vocab(model)); auto n_candidates = (int32_t)candidates.size(); for (int32_t token = 0; token < n_candidates; token++) { candidates[token] = llama_token_data{ token, logits[token], 0.0f }; } auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; - llama_token token = llama_sample_token_greedy(ctx, &candidates_p); + llama_token token = llama_sampling_sample_greedy(smpl, &candidates_p); if (token == eos_token) { break; } @@ -167,10 +168,10 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(mdl, cparams); + llama_context * ctx = llama_new_context_with_model(model, cparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -191,7 +192,7 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(mdl); + const int n_embd = llama_n_embd(model); const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); @@ -212,7 +213,7 @@ int main(int argc, char * argv[]) { } llama_free(ctx); - llama_free_model(mdl); + llama_free_model(model); llama_backend_free(); return 0; diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 574f5ed9c..c216a1b31 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -638,7 +638,7 @@ int main(int argc, char ** argv) { g_collector.save_imatrix(); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 91c777aad..c1359ffc8 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -34,6 +34,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling_context ** g_ctx_sampling; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +94,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl, (*g_ctx_sampling)->grammar); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -171,11 +172,13 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_sampling_context * ctx_sampling = nullptr; g_model = &model; g_ctx = &ctx; + g_ctx_sampling = &ctx_sampling; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -346,7 +349,7 @@ int main(int argc, char ** argv) { std::vector embd; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0); + ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx)); while (n_remain != 0 || params.interactive) { // predict @@ -635,7 +638,7 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a6497b6e0..0b07757b1 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1434,7 +1434,7 @@ int main(int argc, char ** argv) { fflush(p_err->fout); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); llama_free(ctx); } diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 14350bdf7..05d4f82a5 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, llama_get_sampling(ctx_llava->ctx_llama)); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -310,7 +310,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); @@ -327,7 +327,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index b8cddb660..6858835f1 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -118,7 +118,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx)); // verification n-grams std::vector ngrams_cur(G); @@ -468,7 +468,7 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar); llama_kv_cache_view_free(&kvc_view); llama_sampling_free(ctx_sampling); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 3ccc6dfb4..6a1829537 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -106,7 +106,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx)); std::vector draft; @@ -241,7 +241,7 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar); llama_sampling_free(ctx_sampling); llama_batch_free(batch_tgt); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1819df198..608f68356 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling_context ** g_ctx_sampling; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -105,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl, (*g_ctx_sampling)->grammar); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { llama_chat_msg new_msg{role, content}; - auto formatted = llama_chat_format_single( - model, g_params->chat_template, chat_msgs, new_msg, role == "user"); + auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); return formatted; } @@ -197,12 +197,16 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; - llama_context * ctx_guidance = NULL; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_context * ctx_guidance = nullptr; + llama_sampling_context * ctx_sampling = nullptr; + std::vector chat_msgs; + g_model = &model; g_ctx = &ctx; + g_ctx_sampling = &ctx_sampling; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -527,7 +531,7 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0); + ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx)); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -975,7 +979,7 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl, ctx_sampling->grammar); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); if (ctx_guidance) { llama_free(ctx_guidance); } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index d08e07ca2..512613c19 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -51,6 +51,7 @@ static std::vector k_prompts = { struct client { ~client() { if (ctx_sampling) { + llama_sampling_free(ctx_sampling->smpl); llama_sampling_free(ctx_sampling); } } @@ -161,7 +162,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i); + client.ctx_sampling = llama_sampling_init(params.sparams, llama_sampling_init(llama_n_vocab(model))); } std::vector tokens_system; @@ -371,7 +372,7 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); @@ -413,7 +414,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); - llama_print_timings(ctx); + // TODO: print sampling/grammar timings for all clients + llama_print_timings(ctx, nullptr, nullptr); llama_batch_free(batch); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1..b01920b93 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -80,12 +80,13 @@ int main(int argc, char ** argv) { GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); llama_context * ctx = llama_new_context_with_model(model, ctx_params); - if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; } + llama_sampling * smpl = llama_get_sampling(ctx); + // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); @@ -230,7 +231,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { @@ -267,7 +268,7 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); fprintf(stderr, "\n"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index dbe445391..63411c2b2 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -2054,7 +2054,7 @@ int main(int argc, char ** argv) { results = perplexity(ctx, params, n_ctx); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); write_logfile(ctx, params, model, results); llama_free(ctx); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index eb89d16da..20aba9a8f 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -292,7 +292,7 @@ int main(int argc, char ** argv) { } // clean up - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); llama_free(ctx); llama_free_model(model); llama_backend_free(); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 00c2277ac..8b3518889 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -37,6 +37,8 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling * smpl = llama_get_sampling(ctx); + // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -72,7 +74,7 @@ int main(int argc, char ** argv) { candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx, &candidates_p); + auto next_token = llama_sampling_sample(smpl, &candidates_p); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -95,6 +97,8 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + llama_sampling * smpl2 = llama_get_sampling(ctx2); + printf("\nsecond run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file @@ -128,7 +132,7 @@ int main(int argc, char ** argv) { candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx2, &candidates_p); + auto next_token = llama_sampling_sample(smpl2, &candidates_p); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -153,7 +157,9 @@ int main(int argc, char ** argv) { } // make new context - auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + llama_sampling * smpl3 = llama_get_sampling(ctx3); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -216,7 +222,7 @@ int main(int argc, char ** argv) { candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx3, &candidates_p); + auto next_token = llama_sampling_sample(smpl3, &candidates_p); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3b5ed6d4d..672c71772 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -664,6 +664,7 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { if (slot.ctx_sampling != nullptr) { + llama_sampling_free(slot.ctx_sampling->smpl); llama_sampling_free(slot.ctx_sampling); } } @@ -1088,9 +1089,11 @@ struct server_context { { if (slot.ctx_sampling != nullptr) { + llama_sampling_free(slot.ctx_sampling->smpl); llama_sampling_free(slot.ctx_sampling); } - slot.ctx_sampling = llama_sampling_init(slot.sparams, ctx, slot.id); + + slot.ctx_sampling = llama_sampling_init(slot.sparams, llama_sampling_init(llama_n_vocab(model))); if (slot.ctx_sampling == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -2402,7 +2405,7 @@ struct server_context { // Make sure at least n_probs top tokens are at the front of the vector: if (slot.sparams.temp == 0.0f && n_probs > n_valid) { - llama_sample_top_k(ctx, &cur_p, n_probs, 0); + llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0); } if (slot.sparams.temp == 0.0f) { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 69a92cf7d..02c3c07a5 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,6 +55,8 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling * smpl = llama_get_sampling(ctx); + // tokenize the prompt std::vector tokens_list; @@ -123,7 +125,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -160,7 +162,7 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr, nullptr); fprintf(stderr, "\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 569d95522..da348a8bf 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -174,8 +174,8 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; - // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx_tgt, 0); + // target model sampling context (reuse the llama_context's sampling instance) + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx_tgt)); // draft sequence data std::vector drafts(n_seq_dft); @@ -186,7 +186,8 @@ int main(int argc, char ** argv) { } for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].ctx_sampling = llama_sampling_init(params.sparams, ctx_dft, s); + // allocate llama_sampling for each draft sequence + drafts[s].ctx_sampling = llama_sampling_init(params.sparams, llama_sampling_init(llama_n_vocab(model_dft))); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -230,8 +231,10 @@ int main(int argc, char ** argv) { // stochastic verification llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); - llama_sample_softmax(ctx_tgt, &dist_tgt); - float p_tgt = 0, p_dft = 0; + llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt); + + float p_tgt = 0.0f; + float p_dft = 0.0f; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); @@ -327,7 +330,7 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sample_token(ctx_tgt, &dist_tgt); + token_id = llama_sampling_sample(ctx_sampling->smpl, &dist_tgt); llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -589,13 +592,15 @@ int main(int argc, char ** argv) { LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ndraft:\n"); - llama_print_timings(ctx_dft); + // TODO: print sampling/grammar timings for all drafts + llama_print_timings(ctx_dft, nullptr, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt); + llama_print_timings(ctx_tgt, ctx_sampling->smpl, ctx_sampling->grammar); llama_sampling_free(ctx_sampling); for (int s = 0; s < n_seq_dft; ++s) { + llama_sampling_free(drafts[s].ctx_sampling->smpl); llama_sampling_free(drafts[s].ctx_sampling); } diff --git a/include/llama.h b/include/llama.h index adc1b72cc..57937ac10 100644 --- a/include/llama.h +++ b/include/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 7 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 @@ -394,16 +394,22 @@ extern "C" { uint32_t value; // Unicode code point or rule ID } llama_grammar_element; + // sampling types + struct llama_sampling; + // performance timing information struct llama_timings { double t_start_ms; double t_end_ms; double t_load_ms; - double t_sample_ms; + double t_sampling_ms; + double t_grammar_ms; double t_p_eval_ms; double t_eval_ms; - int32_t n_sample; + int32_t n_sampling; + int32_t n_grammar_sample; + int32_t n_grammar_accept; int32_t n_p_eval; int32_t n_eval; }; @@ -454,7 +460,8 @@ extern "C" { LLAMA_API bool llama_supports_mlock (void); LLAMA_API bool llama_supports_gpu_offload(void); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API struct llama_sampling * llama_get_sampling( struct llama_context * ctx); LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); @@ -1028,85 +1035,87 @@ extern "C" { // Sampling functions // - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + // TODO: args become llama_sampling_params + LLAMA_API struct llama_sampling * llama_sampling_init(int32_t n_vocab); - LLAMA_API DEPRECATED(void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id), - "temporary API, until llama_sampling_context is implemented, do not use"); + LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); + + // Sets the current rng seed. + LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed); /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present); + LLAMA_API void llama_sampling_repetition_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @param logits Logits extracted from the original generation context. /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - LLAMA_API void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale); + LLAMA_API void llama_sampling_apply_guidance( + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax( - struct llama_context * ctx, + LLAMA_API void llama_sampling_softmax( + struct llama_sampling * smpl, llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k( - struct llama_context * ctx, - llama_token_data_array * candidates, - int32_t k, - size_t min_keep); + LLAMA_API void llama_sampling_top_k( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + int32_t k, + size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_top_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API void llama_sample_min_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_min_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free( - struct llama_context * ctx, - llama_token_data_array * candidates, - float z, - size_t min_keep); + LLAMA_API void llama_sampling_tail_free( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float z, + size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_typical( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API void llama_sample_entropy( - struct llama_context * ctx, - llama_token_data_array * candidates_p, - float min_temp, - float max_temp, - float exponent_val); + LLAMA_API void llama_sampling_entropy( + struct llama_sampling * smpl, + llama_token_data_array * candidates_p, + float min_temp, + float max_temp, + float exponent_val); - LLAMA_API void llama_sample_temp( - struct llama_context * ctx, - llama_token_data_array * candidates, - float temp); + LLAMA_API void llama_sampling_temp( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float temp); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -1114,43 +1123,36 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - int32_t m, - float * mu); + LLAMA_API llama_token llama_sampling_sample_mirostat( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float tau, + float eta, + int32_t m, + float * mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - float * mu); + LLAMA_API llama_token llama_sampling_sample_mirostat_v2( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); /// @details Selects the token with the highest probability. /// Does not compute the token probabilities. Use llama_sample_softmax() instead. - LLAMA_API llama_token llama_sample_token_greedy( - struct llama_context * ctx, - llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampling_sample_greedy( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of ctx. - LLAMA_API llama_token llama_sample_token( - struct llama_context * ctx, - llama_token_data_array * candidates); - - /// @details Same as llama_sample_token, but uses a seqeuence-specific RNG[seq_id]. - LLAMA_API DEPRECATED(llama_token llama_sample_token_seq( - struct llama_context * ctx, - llama_token_data_array * candidates, - llama_seq_id seq_id), - "temporary API, until llama_sampling_context is implemented, do not use"); + /// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of smpl. + LLAMA_API llama_token llama_sampling_sample( + struct llama_sampling * smpl, + llama_token_data_array * candidates); // // Model split @@ -1169,8 +1171,8 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx); - LLAMA_API void llama_reset_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar); + LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 77d748144..a41e8f71d 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -438,7 +438,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; + return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -446,7 +446,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8 }; + llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 40d82af73..2c313737d 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -11,6 +11,11 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + mutable int64_t t_total_us; + + mutable int32_t n_sample; + mutable int32_t n_accept; }; struct llama_grammar * llama_get_grammar(struct llama_context * ctx); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 670d00420..9efdcfcb0 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -21,7 +21,15 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { +struct llama_sampling * llama_sampling_init_impl(int32_t n_vocab) { + return new llama_sampling(n_vocab); +} + +void llama_sampling_free_impl(struct llama_sampling * sampling) { + delete sampling; +} + +void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = time(NULL); } @@ -29,7 +37,7 @@ void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { smpl.rng.seed(seed); } -void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { +void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); // Sort the logits in descending order @@ -54,7 +62,7 @@ void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_dat } } -void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sampling_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; @@ -129,12 +137,12 @@ void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_ candidates->size = k; } -void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -155,7 +163,7 @@ void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_arra candidates->size = last_idx; } -void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } @@ -210,12 +218,12 @@ void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_ } } -void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -264,7 +272,7 @@ void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_ candidates->size = last_idx; } -void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -272,7 +280,7 @@ void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_ar } // Compute the softmax of logits and calculate entropy - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -322,7 +330,7 @@ void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_ar candidates->sorted = false; } -void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { +void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -331,7 +339,7 @@ void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_ar // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -383,13 +391,13 @@ void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_ar #endif } -void llama_sample_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) { +void llama_sampling_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } } -void llama_sample_repetition_penalties_impl( +void llama_sampling_repetition_penalties_impl( struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, const llama_token * last_tokens, @@ -430,7 +438,7 @@ void llama_sample_repetition_penalties_impl( candidates->sorted = false; } -void llama_sample_apply_guidance_impl( +void llama_sampling_apply_guidance_impl( struct llama_sampling & smpl, float * logits, float * logits_guidance, @@ -448,10 +456,10 @@ void llama_sample_apply_guidance_impl( } } -llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { +llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { const int32_t n_vocab = float(smpl.n_vocab); - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -470,8 +478,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k_impl(smpl, candidates, int(k), 1); - llama_token X = llama_sample_token_impl(smpl, candidates); + llama_sampling_top_k_impl(smpl, candidates, int(k), 1); + llama_token X = llama_sampling_sample_impl(smpl, candidates); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -486,8 +494,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama return X; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - llama_sample_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { + llama_sampling_softmax_impl(smpl, candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -499,10 +507,10 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, ll } // Normalize the probabilities of the remaining words - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Sample the next word X from the remaining words - llama_token X = llama_sample_token_impl(smpl, candidates); + llama_token X = llama_sampling_sample_impl(smpl, candidates); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -517,7 +525,7 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, ll return X; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { +llama_token llama_sampling_sample_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { // Find max element auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; @@ -528,8 +536,8 @@ llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, lla return result; } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - llama_sample_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) { + llama_sampling_softmax_impl(smpl, candidates); std::vector probs; probs.reserve(candidates->size); @@ -545,6 +553,6 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama return result; } -llama_token llama_sample_token_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl.rng); +llama_token llama_sampling_sample_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { + return llama_sampling_sample_with_rng_impl(smpl, candidates, smpl.rng); } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 9f4d0c63d..92f9d06bb 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -3,29 +3,37 @@ #include "llama-impl.h" struct llama_sampling { - llama_sampling(uint32_t seed, int32_t n_vocab) : rng(seed), n_vocab(n_vocab) {} + llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + + const int32_t n_vocab; std::mt19937 rng; - const int32_t n_vocab; + mutable int64_t t_total_us = 0; + + mutable int32_t n_sample = 0; }; // // internal API // -void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed); +struct llama_sampling * llama_sampling_init_impl(int32_t n_vocab); -void llama_sample_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp); +void llama_sampling_free_impl(struct llama_sampling * sampling); -void llama_sample_repetition_penalties_impl( +void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed); + +void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sampling_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sampling_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sampling_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp); + +void llama_sampling_repetition_penalties_impl( struct llama_sampling & smpl, llama_token_data_array * candidates, const llama_token * last_tokens, @@ -34,15 +42,15 @@ void llama_sample_repetition_penalties_impl( float penalty_freq, float penalty_present); -void llama_sample_apply_guidance_impl( +void llama_sampling_apply_guidance_impl( struct llama_sampling & smpl, float * logits, float * logits_guidance, float scale); -llama_token llama_sample_token_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +llama_token llama_sampling_sample_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); +llama_token llama_sampling_sample_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +llama_token llama_sampling_sample_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sampling_sample_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); diff --git a/src/llama.cpp b/src/llama.cpp index 1450ca23a..6e3358231 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2669,6 +2669,7 @@ struct llama_model { struct llama_context { llama_context(const llama_model & model) : model(model) + , sampling(llama_n_vocab(&model)) , grammar() , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} @@ -2686,12 +2687,11 @@ struct llama_context { const struct llama_model & model; struct llama_cparams cparams; + struct llama_sampling sampling; struct llama_grammar grammar; struct llama_kv_cache kv_self; struct llama_control_vector cvec; - std::vector sampling; // sampling context for each sequence - std::unordered_map lora_adapters; std::vector backends; @@ -2707,14 +2707,12 @@ struct llama_context { mutable int64_t t_start_us; mutable int64_t t_load_us; - mutable int64_t t_sample_us = 0; mutable int64_t t_p_eval_us = 0; mutable int64_t t_eval_us = 0; mutable int64_t t_compute_start_us = 0; mutable int64_t n_queued_tokens = 0; - mutable int32_t n_sample = 0; mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) mutable int32_t n_eval = 0; // number of eval calls @@ -16542,10 +16540,7 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->sampling.reserve(cparams.n_seq_max); - for (uint32_t i = 0; i < cparams.n_seq_max; ++i) { - ctx->sampling.emplace_back(params.seed, llama_n_vocab(model)); - } + llama_sampling_set_rng_seed_impl(ctx->sampling, params.seed); ctx->logits_all = params.logits_all; @@ -16827,6 +16822,10 @@ const struct llama_model * llama_get_model(const struct llama_context * ctx) { return &ctx->model; } +struct llama_sampling * llama_get_sampling(struct llama_context * ctx) { + return &ctx->sampling; +} + const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) { return &ctx->model.vocab; } @@ -17322,7 +17321,6 @@ size_t llama_state_get_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes. - const size_t s_n_rng = sizeof(uint32_t); const size_t s_rng_size = sizeof(size_t); const size_t s_rng = LLAMA_MAX_RNG_STATE; const size_t s_n_outputs = sizeof(size_t); @@ -17342,8 +17340,8 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; const size_t s_total = ( - + s_n_rng - + cparams.n_seq_max*(s_rng_size + s_rng) + + s_rng_size + + s_rng + s_n_outputs + s_output_pos + s_logits_size @@ -17360,7 +17358,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { ); // on session change it is very likely that the state size has changed - so we need to update this function - static_assert(LLAMA_SESSION_VERSION == 8, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); + static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); return s_total; } @@ -17423,22 +17421,16 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data // copy rngs { - const uint32_t n_rng = ctx->sampling.size(); + std::ostringstream rng_ss; + rng_ss << ctx->sampling.rng; - data_ctx->write(&n_rng, sizeof(n_rng)); + const std::string & rng_str = rng_ss.str(); + const size_t rng_size = rng_str.size(); - for (const auto & smpl : ctx->sampling) { - std::ostringstream rng_ss; - rng_ss << smpl.rng; + GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); - const std::string & rng_str = rng_ss.str(); - const size_t rng_size = rng_str.size(); - - GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); - - data_ctx->write(&rng_size, sizeof(rng_size)); - data_ctx->write(rng_str.data(), rng_size); - } + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(rng_str.data(), rng_size); } // copy outputs @@ -17588,24 +17580,17 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { // set rngs { - uint32_t n_rng; - memcpy(&n_rng, inp, sizeof(n_rng)); inp += sizeof(n_rng); + size_t rng_size; + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); - GGML_ASSERT(n_rng == ctx->cparams.n_seq_max); + GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); - for (auto & smpl : ctx->sampling) { - size_t rng_size; - memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + std::string rng_str((const char *)inp, rng_size); inp += rng_size; - GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); + std::istringstream rng_ss(rng_str); + rng_ss >> ctx->sampling.rng; - std::string rng_str((const char *)inp, rng_size); inp += rng_size; - - std::istringstream rng_ss(rng_str); - rng_ss >> smpl.rng; - - GGML_ASSERT(!rng_ss.fail()); - } + GGML_ASSERT(!rng_ss.fail()); } // set output ids @@ -18978,11 +18963,14 @@ void llama_grammar_sample( const struct llama_grammar * grammar, const struct llama_context * ctx, llama_token_data_array * candidates) { - time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling + time_meas tm(grammar->t_total_us); // TODO: measure grammar time separately from sampling llama_grammar_sample_impl(*grammar, ctx->model.vocab, candidates); + + grammar->n_sample++; } +// deprecated void llama_sample_grammar( struct llama_context * ctx, llama_token_data_array * candidates, @@ -18994,140 +18982,148 @@ void llama_grammar_accept_token( struct llama_grammar * grammar, struct llama_context * ctx, llama_token token) { - time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling + time_meas tm(grammar->t_total_us); // TODO: measure grammar time separately from sampling llama_grammar_accept_token_impl(*grammar, ctx->model.vocab, token); + + grammar->n_accept++; } // // sampling // -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - llama_set_rng_seed_impl(ctx->sampling[0], seed); +struct llama_sampling * llama_sampling_init(int32_t n_vocab) { + return llama_sampling_init_impl(n_vocab); } -void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id seq_id) { - llama_set_rng_seed_impl(ctx->sampling[seq_id], seed); +void llama_sampling_free(struct llama_sampling * smpl) { + if (smpl == nullptr) { + return; + } + + llama_sampling_free_impl(smpl); } -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - time_meas tm(ctx->t_sample_us); - - llama_sample_softmax_impl(ctx->sampling[0], candidates); +void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) { + llama_sampling_set_rng_seed_impl(*smpl, seed); } -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); - llama_sample_top_k_impl(ctx->sampling[0], candidates, k, min_keep); + llama_sampling_softmax_impl(*smpl, candidates); } -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { + time_meas tm(smpl->t_total_us); - llama_sample_top_p_impl(ctx->sampling[0], candidates, p, min_keep); + llama_sampling_top_k_impl(*smpl, candidates, k, min_keep); } -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + time_meas tm(smpl->t_total_us); - llama_sample_min_p_impl(ctx->sampling[0], candidates, p, min_keep); + llama_sampling_top_p_impl(*smpl, candidates, p, min_keep); } -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + time_meas tm(smpl->t_total_us); - llama_sample_tail_free_impl(ctx->sampling[0], candidates, z, min_keep); + llama_sampling_min_p_impl(*smpl, candidates, p, min_keep); } -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { + time_meas tm(smpl->t_total_us); - llama_sample_typical_impl(ctx->sampling[0], candidates, p, min_keep); + llama_sampling_tail_free_impl(*smpl, candidates, z, min_keep); } -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + time_meas tm(smpl->t_total_us); - llama_sample_entropy_impl(ctx->sampling[0], candidates_p, min_temp, max_temp, exponent_val); + llama_sampling_typical_impl(*smpl, candidates, p, min_keep); } -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - time_meas tm(ctx->t_sample_us); +void llama_sampling_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { + time_meas tm(smpl->t_total_us); - llama_sample_temp_impl(ctx->sampling[0], candidates_p, temp); + llama_sampling_entropy_impl(*smpl, candidates_p, min_temp, max_temp, exponent_val); } -void llama_sample_repetition_penalties( - struct llama_context * ctx, +void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float temp) { + time_meas tm(smpl->t_total_us); + + llama_sampling_temp_impl(*smpl, candidates_p, temp); +} + +void llama_sampling_repetition_penalties( + struct llama_sampling * smpl, llama_token_data_array * candidates, const llama_token * last_tokens, size_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present) { - time_meas tm(ctx->t_sample_us); + time_meas tm(smpl->t_total_us); - llama_sample_repetition_penalties_impl(ctx->sampling[0], candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); + llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); } -void llama_sample_apply_guidance( - struct llama_context * ctx, +void llama_sampling_apply_guidance( + struct llama_sampling * smpl, float * logits, float * logits_guidance, float scale) { - time_meas tm(ctx->t_sample_us); + time_meas tm(smpl->t_total_us); - llama_sample_apply_guidance_impl(ctx->sampling[0], logits, logits_guidance, scale); + llama_sampling_apply_guidance_impl(*smpl, logits, logits_guidance, scale); } -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - time_meas tm(ctx->t_sample_us); +llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + time_meas tm(smpl->t_total_us); - auto res = llama_sample_token_mirostat_impl(ctx->sampling[0], candidates, tau, eta, m, mu); + auto res = llama_sampling_sample_mirostat_impl(*smpl, candidates, tau, eta, m, mu); - ctx->n_sample++; + smpl->n_sample++; return res; } -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - time_meas tm(ctx->t_sample_us); +llama_token llama_sampling_sample_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { + time_meas tm(smpl->t_total_us); - auto res = llama_sample_token_mirostat_v2_impl(ctx->sampling[0], candidates, tau, eta, mu); + auto res = llama_sampling_sample_mirostat_v2_impl(*smpl, candidates, tau, eta, mu); - ctx->n_sample++; + smpl->n_sample++; return res; } -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - time_meas tm(ctx->t_sample_us); +llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); - auto res = llama_sample_token_greedy_impl(ctx->sampling[0], candidates); + auto res = llama_sampling_sample_greedy_impl(*smpl, candidates); - ctx->n_sample++; + smpl->n_sample++; return res; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_seq(ctx, candidates, 0); -} +llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); -llama_token llama_sample_token_seq(struct llama_context * ctx, llama_token_data_array * candidates, llama_seq_id seq_id) { - GGML_ASSERT(seq_id >= 0 && seq_id < (int32_t) ctx->cparams.n_seq_max); + auto res = llama_sampling_sample_impl(*smpl, candidates); - time_meas tm(ctx->t_sample_us); - - auto res = llama_sample_token_impl(ctx->sampling[seq_id], candidates); - - ctx->n_sample++; + smpl->n_sample++; return res; } +// +// model split +// + int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { @@ -19152,30 +19148,29 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -struct llama_timings llama_get_timings(struct llama_context * ctx) { - struct llama_timings result = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, +void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar) { + const llama_timings timings = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_total_us : ctx->sampling.t_total_us), + /*.t_grammar_ms =*/ grammar ? (1e-3 * grammar->t_total_us) : 0.0, + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - /*.n_sample =*/ std::max(1, ctx->n_sample), - /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), + /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : ctx->sampling.n_sample), + /*.n_grammar_sample =*/ std::max(0, grammar ? grammar->n_sample : 0), + /*.n_grammar_accept =*/ std::max(0, grammar ? grammar->n_accept : 0), + /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), }; - return result; -} - -void llama_print_timings(struct llama_context * ctx) { - const llama_timings timings = llama_get_timings(ctx); - LLAMA_LOG_INFO("\n"); LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); + LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_grammar_ms, timings.n_grammar_sample, timings.t_grammar_ms / timings.n_grammar_sample, 1e3 / timings.t_grammar_ms * timings.n_grammar_sample); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", @@ -19183,11 +19178,18 @@ void llama_print_timings(struct llama_context * ctx) { LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl, struct llama_grammar * grammar) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; - ctx->t_sample_us = ctx->n_sample = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; + + if (smpl) { + smpl->t_total_us = smpl->n_sample = 0; + } + + if (grammar) { + grammar->t_total_us = grammar->n_sample = grammar->n_accept = 0; + } } const char * llama_print_system_info(void) { @@ -19233,21 +19235,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { 1.0e-3 * ctx->t_eval_us / ctx->n_eval); fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->t_sample_us / ctx->n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); - fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->n_sample / ctx->t_sample_us); } // For internal test use diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 7284ad9a3..6e4c2e314 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,5 +1,5 @@ #include "ggml.h" -#include "llama-sampling.h" +#include "llama.h" #ifdef NDEBUG #undef NDEBUG @@ -20,7 +20,7 @@ static void dump(const llama_token_data_array * candidates) { static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { const size_t n_vocab = probs.size(); - llama_sampling smpl(1234, n_vocab); + llama_sampling * smpl = llama_sampling_init(n_vocab); std::vector candidates; candidates.reserve(n_vocab); @@ -30,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - llama_sampling smpl(1234, n_vocab); + llama_sampling * smpl = llama_sampling_init(n_vocab); std::vector candidates; candidates.reserve(n_vocab); @@ -53,9 +53,9 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float z) { const size_t n_vocab = probs.size(); - llama_sampling smpl(1234, n_vocab); + llama_sampling * smpl = llama_sampling_init(n_vocab); std::vector candidates; candidates.reserve(n_vocab); @@ -77,7 +77,7 @@ static void test_tfs(const std::vector & probs, const std::vector llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sample_tail_free_impl(smpl, &candidates_p, z, 1); + llama_sampling_tail_free(smpl, &candidates_p, z, 1); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -88,7 +88,7 @@ static void test_tfs(const std::vector & probs, const std::vector static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - llama_sampling smpl(1234, n_vocab); + llama_sampling * smpl = llama_sampling_init(n_vocab); std::vector candidates; candidates.reserve(n_vocab); @@ -99,9 +99,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - llama_sampling smpl(1234, n_vocab); + llama_sampling * smpl = llama_sampling_init(n_vocab); std::vector candidates; candidates.reserve(n_vocab); @@ -122,7 +122,7 @@ static void test_typical(const std::vector & probs, const std::vector candidates; candidates.reserve(n_vocab); @@ -148,10 +148,10 @@ static void test_repetition_penalties( } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_sample_softmax_impl(smpl, &candidates_p); + llama_sampling_softmax(smpl, &candidates_p); DUMP(&candidates_p); - llama_sample_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); - llama_sample_softmax_impl(smpl, &candidates_p); + llama_sampling_repetition_penalties(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); + llama_sampling_softmax(smpl, &candidates_p); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -162,7 +162,7 @@ static void test_repetition_penalties( static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { - llama_sampling smpl(1234, n_vocab); + llama_sampling * smpl = llama_sampling_init(n_vocab); std::vector candidates; candidates.reserve(n_vocab); @@ -178,16 +178,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler for (auto s : samplers_sequence) { switch (s){ - case 'k': llama_sample_top_k_impl(smpl, &candidates_p, top_k, 1); break; + case 'k': llama_sampling_top_k(smpl, &candidates_p, top_k, 1); break; case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break; case 'y': GGML_ASSERT(false && "typical test not implemented"); break; - case 'p': llama_sample_top_p_impl(smpl, &candidates_p, top_p, 1); break; - case 'm': llama_sample_min_p_impl(smpl, &candidates_p, min_p, 1); break; + case 'p': llama_sampling_top_p(smpl, &candidates_p, top_p, 1); break; + case 'm': llama_sampling_min_p(smpl, &candidates_p, min_p, 1); break; case 't': GGML_ASSERT(false && "temperature test not implemented"); break; default : GGML_ASSERT(false && "Unknown sampler"); break; } - llama_sample_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests + llama_sampling_softmax(smpl, &candidates_p); // make sure tokens are sorted for tests const int size = candidates_p.size;