diff --git a/common/sampling.cpp b/common/sampling.cpp index 079e40516..f1f41daab 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,11 +1,13 @@ -#define LLAMA_API_INTERNAL #include "sampling.h" + #include -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) { struct llama_sampling_context * result = new llama_sampling_context(); result->params = params; + result->seq_id = seq_id; + result->ctx = ctx; result->grammar = nullptr; // if there is a grammar, parse it @@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s if (seed == LLAMA_DEFAULT_SEED) { seed = std::random_device{}(); } - ctx->rng.seed(seed); + llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id); } void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { @@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl( bool is_resampling) { const llama_sampling_params & params = ctx_sampling->params; - const float temp = params.temp; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; + const float temp = params.temp; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; std::vector original_logits; auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); @@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl( sampler_queue(ctx_main, params, cur_p, min_keep); - id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); + id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id); //{ // const int n_top = 10; diff --git a/common/sampling.h b/common/sampling.h index eeaa53b8b..f03eb3be3 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -70,9 +70,12 @@ struct llama_sampling_context { // parameters that will be used for sampling llama_sampling_params params; + llama_seq_id seq_id; + // mirostat sampler state float mirostat_mu; + llama_context * ctx; // TMP llama_grammar * grammar; // internal @@ -81,15 +84,14 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector prev; std::vector cur; - size_t n_valid; // Number of correct top tokens with correct probabilities. - std::mt19937 rng; + size_t n_valid; // Number of correct top tokens with correct probabilities. }; #include "common.h" // Create a new sampling context instance. 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id); void llama_sampling_free(struct llama_sampling_context * ctx); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 48a705e15..d1fbd8671 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,8 +1,7 @@ -#define LLAMA_API_INTERNAL - #include "grammar-parser.h" #include "ggml.h" #include "llama.h" +#include "llama-impl.h" #include "unicode.h" #include diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index dc93d2301..91c777aad 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -346,7 +346,7 @@ int main(int argc, char ** argv) { std::vector embd; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0); while (n_remain != 0 || params.interactive) { // predict diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 8c7dd2ae3..14350bdf7 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index fb20ad93f..b8cddb660 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -118,7 +118,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0); // verification n-grams std::vector ngrams_cur(G); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index bb571bac4..3ccc6dfb4 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -106,7 +106,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0); std::vector draft; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index a0d817b1a..1819df198 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -527,7 +527,7 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7faeaec97..d08e07ca2 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -161,7 +161,7 
@@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(params.sparams); + client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i); } std::vector tokens_system; diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 68cf8d359..25c2de60c 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,7 +1,7 @@ -#define LLAMA_API_INTERNAL #include "common.h" #include "ggml.h" #include "llama.h" +#include "llama-impl.h" #include #include diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7813a2957..3b5ed6d4d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1090,7 +1090,7 @@ struct server_context { if (slot.ctx_sampling != nullptr) { llama_sampling_free(slot.ctx_sampling); } - slot.ctx_sampling = llama_sampling_init(slot.sparams); + slot.ctx_sampling = llama_sampling_init(slot.sparams, ctx, slot.id); if (slot.ctx_sampling == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0939a1a6a..569d95522 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -175,7 +175,7 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx_tgt, 0); // draft sequence data std::vector drafts(n_seq_dft); @@ -186,7 +186,7 @@ int main(int argc, char ** argv) { } for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].ctx_sampling = llama_sampling_init(params.sparams); + drafts[s].ctx_sampling = llama_sampling_init(params.sparams, ctx_dft, s); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); diff --git a/include/llama.h b/include/llama.h index e68cd807e..adc1b72cc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 7 +#define LLAMA_SESSION_VERSION 8 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 @@ -1031,6 +1031,9 @@ extern "C" { // Sets the current rng seed. LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + LLAMA_API DEPRECATED(void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id), + "temporary API, until llama_sampling_context is implemented, do not use"); + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. LLAMA_API void llama_sample_repetition_penalties( @@ -1137,11 +1140,18 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. + /// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of ctx. 
LLAMA_API llama_token llama_sample_token( struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Same as llama_sample_token, but uses a sequence-specific RNG[seq_id]. + LLAMA_API DEPRECATED(llama_token llama_sample_token_seq( + struct llama_context * ctx, + llama_token_data_array * candidates, + llama_seq_id seq_id), + "temporary API, until llama_sampling_context is implemented, do not use"); + // // Model split // @@ -1175,59 +1185,4 @@ extern "C" { } #endif -// Internal API to be implemented by llama.cpp and used by tests/benchmarks only -#ifdef LLAMA_API_INTERNAL - -#include -#include -#include - -struct ggml_tensor; - -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); - -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - -// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. -// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); - -#endif // LLAMA_API_INTERNAL - #endif // LLAMA_H diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index bd9322e2f..77d748144 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -445,15 +445,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { - llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) { + llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8 }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { result->stacks[is][ie] = &result->rules[ir0][ir1]; } } @@ -464,14 +464,9 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram return result; } -void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(grammar); - GGML_ASSERT(vocab); - - int64_t t_start_sample_us = ggml_time_us(); - +void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) { bool allow_eog = false; - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; break; @@ -486,33 +481,29 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + const std::string & piece = vocab.cache_token_to_piece.at(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (llama_token_is_eog_impl(vocab, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { candidates->data[reject.index].logit = -INFINITY; } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { - const int64_t t_start_sample_us = ggml_time_us(); - - if 
(llama_token_is_eog_impl(*vocab, token)) { - for (const auto & stack : grammar->stacks) { +void llama_grammar_accept_token_impl(struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) { + if (llama_token_is_eog_impl(vocab, token)) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -520,20 +511,18 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ASSERT(false); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = vocab.cache_token_to_piece.at(token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; llama_grammar_stacks tmp_new_stacks; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - grammar->stacks = tmp_new_stacks; + llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks); + grammar.stacks = tmp_new_stacks; } - grammar->partial_utf8 = decoded.second; - GGML_ASSERT(!grammar->stacks.empty()); - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + grammar.partial_utf8 = decoded.second; + GGML_ASSERT(!grammar.stacks.empty()); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 8e578e09f..40d82af73 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -26,16 +26,14 @@ struct llama_grammar * llama_grammar_init_impl( void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar); void llama_grammar_sample_impl( - const struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, + const struct llama_grammar & grammar, + const struct llama_vocab & vocab, llama_token_data_array * candidates); void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, + struct llama_grammar & grammar, + const struct llama_vocab & vocab, llama_token token); diff --git a/src/llama-impl.h b/src/llama-impl.h index dcc8c1c15..c18967752 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,8 +1,11 @@ #pragma once -#define LLAMA_API_INTERNAL #include "llama.h" +#include +#include +#include + #ifdef __GNUC__ #ifdef __MINGW32__ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) @@ -24,3 +27,43 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) 
llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const uint32_t chr, + llama_grammar_stacks & new_stacks); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); + +std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8910f6d65..670d00420 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -21,19 +21,17 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { +void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = time(NULL); } - smpl->rng.seed(seed); + smpl.rng.seed(seed); } -void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { +void llama_sample_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); - const int64_t t_start_sample_us = ggml_time_us(); - // Sort the logits in descending order if (!candidates->sorted) { std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { @@ -44,28 +42,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar float max_l = candidates->data[0].logit; float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { float p = expf(candidates->data[i].logit - max_l); candidates->data[i].p = p; cum_sum += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sample_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; // } - const int64_t t_start_sample_us = ggml_time_us(); - if (k <= 0) { k = candidates->size; } @@ -133,21 +127,15 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra candidates->sorted = true; } candidates->size = k; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_p_impl(struct 
llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sample_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } llama_sample_softmax_impl(smpl, candidates); - const int64_t t_start_sample_us = ggml_time_us(); - // Compute the cumulative probabilities float cum_sum = 0.0f; size_t last_idx = candidates->size; @@ -165,19 +153,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the top-p tokens candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sample_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - bool min_p_applied = false; // if the candidates aren't sorted, try the unsorted implementation first @@ -226,19 +208,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the matching tokens candidates->size = i; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax_impl(smpl, candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -285,13 +262,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_ // Resize the output vector to keep only the tokens above the tail location candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sample_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -299,9 +272,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar } // Compute the softmax of logits and calculate entropy - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax_impl(smpl, candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -349,15 +320,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); candidates->size = new_candidates.size(); candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, 
float max_temp, float exponent_val) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sample_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -366,7 +331,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + llama_sample_softmax_impl(smpl, candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -416,38 +381,26 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); } #endif - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sample_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, + struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, const llama_token * last_tokens, size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { + float penalty_repeat, + float penalty_freq, + float penalty_present) { if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - // Create a frequency map to count occurrences of each token in last_tokens std::unordered_map token_count; for (size_t i = 0; i < penalty_last_n; ++i) { @@ -475,21 +428,14 @@ void llama_sample_repetition_penalties_impl( } candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, + struct llama_sampling & smpl, float * logits, float * logits_guidance, float scale) { - GGML_ASSERT(smpl); - - const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = smpl->n_vocab; + const auto n_vocab = smpl.n_vocab; llama_log_softmax(logits, n_vocab); llama_log_softmax(logits_guidance, n_vocab); @@ -500,18 +446,12 @@ void llama_sample_apply_guidance_impl( l = scale * (l - g) + g; } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(smpl); +llama_token llama_sample_token_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + const int32_t n_vocab = float(smpl.n_vocab); - const int32_t n_vocab = float(smpl->n_vocab); - - int64_t t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + llama_sample_softmax_impl(smpl, candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -530,10 +470,8 @@ llama_token 
llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + llama_sample_top_k_impl(smpl, candidates, int(k), 1); llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -545,14 +483,10 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama // Update mu using the learning rate and error *mu = *mu - eta * e; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; return X; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - +llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { llama_sample_softmax_impl(smpl, candidates); // Truncate the words with surprise values greater than mu @@ -564,16 +498,11 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll candidates->size = 1; } - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } - // Normalize the probabilities of the remaining words llama_sample_softmax_impl(smpl, candidates); // Sample the next word X from the remaining words llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -585,33 +514,22 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll // Update mu using the learning rate and error *mu = *mu - eta * e; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } return X; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - +llama_token llama_sample_token_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { // Find max element auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); llama_token result = max_iter->id; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - } + return result; } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(smpl); - - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +llama_token llama_sample_token_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) { + llama_sample_softmax_impl(smpl, candidates); std::vector probs; 
probs.reserve(candidates->size); @@ -624,12 +542,9 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama llama_token result = candidates->data[idx].id; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - return result; } -llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +llama_token llama_sample_token_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { + return llama_sample_token_with_rng_impl(smpl, candidates, smpl.rng); } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index f7f8e3ef7..9f4d0c63d 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -3,38 +3,30 @@ #include "llama-impl.h" struct llama_sampling { - llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + llama_sampling(uint32_t seed, int32_t n_vocab) : rng(seed), n_vocab(n_vocab) {} std::mt19937 rng; - int32_t n_vocab = 0; - - mutable int64_t t_sample_us = 0; - mutable int32_t n_sample = 0; - - void reset_timings() const { - t_sample_us = 0; - n_sample = 0; - } + const int32_t n_vocab; }; // // internal API // -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); +void llama_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed); -void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); +void llama_sample_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +void llama_sample_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sample_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sample_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sample_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp); void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, + struct llama_sampling & smpl, llama_token_data_array * candidates, const llama_token * last_tokens, size_t 
penalty_last_n, @@ -43,14 +35,14 @@ void llama_sample_repetition_penalties_impl( float penalty_present); void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, + struct llama_sampling & smpl, float * logits, float * logits_guidance, float scale); -llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); +llama_token llama_sample_token_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); +llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); +llama_token llama_sample_token_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +llama_token llama_sample_token_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sample_token_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); diff --git a/src/llama.cpp b/src/llama.cpp index 40c5e8e8d..1450ca23a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -163,6 +163,19 @@ static void zeros(std::ofstream & file, size_t n) { } } +struct time_meas { + time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + t_acc += ggml_time_us() - t_start_us; + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + + LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) 
{ va_list ap; @@ -2656,7 +2669,6 @@ struct llama_model { struct llama_context { llama_context(const llama_model & model) : model(model) - , sampling(llama_n_vocab(&model)) , grammar() , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} @@ -2674,11 +2686,12 @@ struct llama_context { const struct llama_model & model; struct llama_cparams cparams; - struct llama_sampling sampling; struct llama_grammar grammar; struct llama_kv_cache kv_self; struct llama_control_vector cvec; + std::vector sampling; // sampling context for each sequence + std::unordered_map lora_adapters; std::vector backends; @@ -2692,16 +2705,18 @@ struct llama_context { bool has_evaluated_once = false; - int64_t t_start_us; - int64_t t_load_us; - int64_t t_p_eval_us = 0; - int64_t t_eval_us = 0; + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_sample_us = 0; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; - int64_t t_compute_start_us = 0; - int64_t n_queued_tokens = 0; + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; - int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - int32_t n_eval = 0; // number of eval calls + mutable int32_t n_sample = 0; + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls // host buffer for the model output (logits and embeddings) ggml_backend_buffer_t buf_output = nullptr; @@ -16527,8 +16542,12 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->sampling.rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; + ctx->sampling.reserve(cparams.n_seq_max); + for (uint32_t i = 0; i < cparams.n_seq_max; ++i) { + ctx->sampling.emplace_back(params.seed, llama_n_vocab(model)); + } + + ctx->logits_all = params.logits_all; uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; @@ -17303,6 +17322,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_n_rng = sizeof(uint32_t); const size_t s_rng_size = sizeof(size_t); const size_t s_rng = LLAMA_MAX_RNG_STATE; const size_t s_n_outputs = sizeof(size_t); @@ -17322,8 +17342,8 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; const size_t s_total = ( - + s_rng_size - + s_rng + + s_n_rng + + cparams.n_seq_max*(s_rng_size + s_rng) + s_n_outputs + s_output_pos + s_logits_size @@ -17340,7 +17360,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { ); // on session change it is very likely that the state size has changed - so we need to update this function - static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); + static_assert(LLAMA_SESSION_VERSION == 8, "So you just bumped the session version - good. 
But did you remember to update llama_state_get_size?"); return s_total; } @@ -17401,18 +17421,24 @@ struct llama_data_file_context : llama_data_context { static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { llama_synchronize(ctx); - // copy rng + // copy rngs { - std::ostringstream rng_ss; - rng_ss << ctx->sampling.rng; + const uint32_t n_rng = ctx->sampling.size(); - const std::string & rng_str = rng_ss.str(); - const size_t rng_size = rng_str.size(); + data_ctx->write(&n_rng, sizeof(n_rng)); - GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); + for (const auto & smpl : ctx->sampling) { + std::ostringstream rng_ss; + rng_ss << smpl.rng; - data_ctx->write(&rng_size, sizeof(rng_size)); - data_ctx->write(rng_str.data(), rng_size); + const std::string & rng_str = rng_ss.str(); + const size_t rng_size = rng_str.size(); + + GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); + + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(rng_str.data(), rng_size); + } } // copy outputs @@ -17560,19 +17586,26 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { const uint8_t * inp = src; - // set rng + // set rngs { - size_t rng_size; - memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + uint32_t n_rng; + memcpy(&n_rng, inp, sizeof(n_rng)); inp += sizeof(n_rng); - GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); + GGML_ASSERT(n_rng == ctx->cparams.n_seq_max); - std::string rng_str((const char *)inp, rng_size); inp += rng_size; + for (auto & smpl : ctx->sampling) { + size_t rng_size; + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); - std::istringstream rng_ss(rng_str); - rng_ss >> ctx->sampling.rng; + GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE); - GGML_ASSERT(!rng_ss.fail()); + std::string rng_str((const char *)inp, rng_size); inp += rng_size; + + std::istringstream rng_ss(rng_str); + rng_ss >> smpl.rng; + + GGML_ASSERT(!rng_ss.fail()); + } } // set output ids @@ -18930,18 +18963,24 @@ struct llama_grammar * llama_grammar_init( } void llama_grammar_free(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } + llama_grammar_free_impl(grammar); } struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); + return llama_grammar_copy_impl(*grammar); } void llama_grammar_sample( const struct llama_grammar * grammar, const struct llama_context * ctx, llama_token_data_array * candidates) { - llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates); + time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling + + llama_grammar_sample_impl(*grammar, ctx->model.vocab, candidates); } void llama_sample_grammar( @@ -18955,7 +18994,9 @@ void llama_grammar_accept_token( struct llama_grammar * grammar, struct llama_context * ctx, llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); + time_meas tm(ctx->t_sample_us); // TODO: measure grammar time separately from sampling + + llama_grammar_accept_token_impl(*grammar, ctx->model.vocab, token); } // @@ -18963,39 +19004,59 @@ void llama_grammar_accept_token( // void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - llama_set_rng_seed_impl(&ctx->sampling, seed); + llama_set_rng_seed_impl(ctx->sampling[0], seed); +} + +void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id seq_id) { + 
llama_set_rng_seed_impl(ctx->sampling[seq_id], seed); } void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates); + time_meas tm(ctx->t_sample_us); + + llama_sample_softmax_impl(ctx->sampling[0], candidates); } void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep); + time_meas tm(ctx->t_sample_us); + + llama_sample_top_k_impl(ctx->sampling[0], candidates, k, min_keep); } void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); + time_meas tm(ctx->t_sample_us); + + llama_sample_top_p_impl(ctx->sampling[0], candidates, p, min_keep); } void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); + time_meas tm(ctx->t_sample_us); + + llama_sample_min_p_impl(ctx->sampling[0], candidates, p, min_keep); } void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); + time_meas tm(ctx->t_sample_us); + + llama_sample_tail_free_impl(ctx->sampling[0], candidates, z, min_keep); } void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); + time_meas tm(ctx->t_sample_us); + + llama_sample_typical_impl(ctx->sampling[0], candidates, p, min_keep); } void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); + time_meas tm(ctx->t_sample_us); + + llama_sample_entropy_impl(ctx->sampling[0], candidates_p, min_temp, max_temp, exponent_val); } void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp); + time_meas tm(ctx->t_sample_us); + + llama_sample_temp_impl(ctx->sampling[0], candidates_p, temp); } void llama_sample_repetition_penalties( @@ -19006,7 +19067,9 @@ void llama_sample_repetition_penalties( float penalty_repeat, float penalty_freq, float penalty_present) { - llama_sample_repetition_penalties_impl(ctx ? 
&ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); + time_meas tm(ctx->t_sample_us); + + llama_sample_repetition_penalties_impl(ctx->sampling[0], candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); } void llama_sample_apply_guidance( @@ -19014,27 +19077,55 @@ void llama_sample_apply_guidance( float * logits, float * logits_guidance, float scale) { - llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale); + time_meas tm(ctx->t_sample_us); + + llama_sample_apply_guidance_impl(ctx->sampling[0], logits, logits_guidance, scale); } llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu); + time_meas tm(ctx->t_sample_us); + + auto res = llama_sample_token_mirostat_impl(ctx->sampling[0], candidates, tau, eta, m, mu); + + ctx->n_sample++; + + return res; } llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); + time_meas tm(ctx->t_sample_us); + + auto res = llama_sample_token_mirostat_v2_impl(ctx->sampling[0], candidates, tau, eta, mu); + + ctx->n_sample++; + + return res; } llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates); -} + time_meas tm(ctx->t_sample_us); -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng); + auto res = llama_sample_token_greedy_impl(ctx->sampling[0], candidates); + + ctx->n_sample++; + + return res; } llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); + return llama_sample_token_seq(ctx, candidates, 0); +} + +llama_token llama_sample_token_seq(struct llama_context * ctx, llama_token_data_array * candidates, llama_seq_id seq_id) { + GGML_ASSERT(seq_id >= 0 && seq_id < (int32_t) ctx->cparams.n_seq_max); + + time_meas tm(ctx->t_sample_us); + + auto res = llama_sample_token_impl(ctx->sampling[seq_id], candidates); + + ctx->n_sample++; + + return res; } int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { @@ -19066,11 +19157,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, + /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - /*.n_sample =*/ std::max(1, ctx->sampling.n_sample), + /*.n_sample =*/ std::max(1, ctx->n_sample), /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_eval =*/ std::max(1, ctx->n_eval), }; @@ -19095,9 +19186,8 @@ void llama_print_timings(struct llama_context * ctx) { void llama_reset_timings(struct llama_context * ctx) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; + ctx->t_sample_us = 
ctx->n_sample = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - - ctx->sampling.reset_timings(); } const char * llama_print_system_info(void) { @@ -19144,20 +19234,20 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); + 1.0e-3 * ctx->t_sample_us / ctx->n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); + fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); + fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); + 1.0e6 * ctx->n_sample / ctx->t_sample_us); } // For internal test use diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 68f971bfe..8786f05e0 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -2,13 +2,13 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL - #include "ggml.h" #include "llama.h" +#include "llama-impl.h" +#include "unicode.h" #include "grammar-parser.h" #include "json-schema-to-grammar.h" -#include "unicode.h" + #include #include #include diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f3a267b3..aa5f2012b 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -2,8 +2,8 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL #include "llama.h" +#include "llama-impl.h" #include "grammar-parser.h" #include diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 6374958fe..7284ad9a3 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,5 +1,5 @@ #include "ggml.h" -#include "llama.h" +#include "llama-sampling.h" #ifdef NDEBUG #undef NDEBUG @@ -20,6 +20,8 @@ static void dump(const llama_token_data_array * candidates) { static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { const size_t n_vocab = probs.size(); + llama_sampling smpl(1234, n_vocab); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -28,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + 
llama_sampling smpl(1234, n_vocab); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -49,9 +53,9 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float z) { const size_t n_vocab = probs.size(); + llama_sampling smpl(1234, n_vocab); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -71,7 +77,7 @@ static void test_tfs(const std::vector & probs, const std::vector llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sample_tail_free(nullptr, &candidates_p, z, 1); + llama_sample_tail_free_impl(smpl, &candidates_p, z, 1); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -82,6 +88,8 @@ static void test_tfs(const std::vector & probs, const std::vector static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + llama_sampling smpl(1234, n_vocab); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -91,9 +99,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + llama_sampling smpl(1234, n_vocab); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -112,7 +122,7 @@ static void test_typical(const std::vector & probs, const std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -136,10 +148,10 @@ static void test_repetition_penalties( } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_sample_softmax(nullptr, &candidates_p); + llama_sample_softmax_impl(smpl, &candidates_p); DUMP(&candidates_p); - llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); - llama_sample_softmax(nullptr, &candidates_p); + llama_sample_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); + llama_sample_softmax_impl(smpl, &candidates_p); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -148,9 +160,10 @@ static void test_repetition_penalties( } } -static void test_sampler_queue( - const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p +static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { + llama_sampling smpl(1234, n_vocab); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -165,16 +178,16 @@ static void test_sampler_queue( for (auto s : samplers_sequence) { switch (s){ - case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; + case 'k': llama_sample_top_k_impl(smpl, &candidates_p, top_k, 1); break; case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break; 
case 'y': GGML_ASSERT(false && "typical test not implemented"); break; - case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; - case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; + case 'p': llama_sample_top_p_impl(smpl, &candidates_p, top_p, 1); break; + case 'm': llama_sample_min_p_impl(smpl, &candidates_p, min_p, 1); break; case 't': GGML_ASSERT(false && "temperature test not implemented"); break; default : GGML_ASSERT(false && "Unknown sampler"); break; } - llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests + llama_sample_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests const int size = candidates_p.size;