From 0e6db6fec1fd759b799a511018d6b21664cfb76d Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Sun, 8 Oct 2023 03:52:45 -0600
Subject: [PATCH] Fix mirostat state when using multiple sequences

---
 common/common.cpp              | 21 +++++++++++++++------
 common/common.h                | 16 ++++++++++++++--
 examples/parallel/parallel.cpp |  4 ++--
 3 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 0f55c33a7..752785f3a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -940,10 +940,11 @@ llama_token llama_sample_token(
                   struct llama_context * ctx,
                   struct llama_context * ctx_guidance,
                   struct llama_grammar * grammar,
-           const struct gpt_params & params,
+                 struct gpt_params & params,
         const std::vector<llama_token> & last_tokens,
          std::vector<llama_token_data> & candidates,
-                                  int idx) {
+                              const int idx,
+                           llama_seq_id seq) {
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -1011,15 +1012,23 @@ llama_token llama_sample_token(
             // Greedy sampling
             id = llama_sample_token_greedy(ctx, &cur_p);
         } else {
+            float * mirostat_mu = NULL;
+            if (mirostat > 0) {
+                seq = std::max(0, seq); // Deal with people passing -1 or something.
+                auto mu_it = params.sampler_state.find(seq);
+                if (mu_it == params.sampler_state.end()) {
+                    const llama_sampler_state new_state = { 2.0f * mirostat_tau };
+                    mu_it = params.sampler_state.insert({seq, new_state}).first;
+                }
+                mirostat_mu = &mu_it->second.mirostat_mu;
+            }
             if (mirostat == 1) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
                 const int mirostat_m = 100;
                 llama_sample_temp(ctx, &cur_p, temp);
-                id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
             } else if (mirostat == 2) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
                 llama_sample_temp(ctx, &cur_p, temp);
-                id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu);
             } else {
                 // Temperature sampling
                 size_t min_keep = std::max(1, params.n_probs);
diff --git a/common/common.h b/common/common.h
index c80215279..3fa77e877 100644
--- a/common/common.h
+++ b/common/common.h
@@ -33,6 +33,10 @@
 //
 int32_t get_num_physical_cores();
 
+typedef struct llama_sampler_state {
+    float mirostat_mu; // mirostat sampler state
+} llama_sampler_state;
+
 struct gpt_params {
     uint32_t seed     = -1; // RNG seed
     int32_t n_threads = get_num_physical_cores();
@@ -54,6 +58,9 @@ struct gpt_params {
     float rope_freq_base  = 0.0f; // RoPE base frequency
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
 
+    // per sequence sampler state
+    std::unordered_map<llama_seq_id, llama_sampler_state> sampler_state;
+
     // sampling parameters
     int32_t top_k = 40;    // <= 0 to use vocab size
     float   top_p = 0.95f; // 1.0 = disabled
@@ -186,6 +193,9 @@ std::string llama_detokenize_bpe(
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
+// Note: When using multiple sequences, it is the caller's responsibility to delete
+// the item in params.sampler_state when a sequence ends and samplers that rely on
+// state are being used.
 //
 // required:
 //  - ctx: context to use for sampling
@@ -196,6 +206,7 @@ std::string llama_detokenize_bpe(
 //  - grammar: grammar to use for sampling, ignore if NULL
 //  - last_tokens: needed for repetition penalty, ignore if empty
 //  - idx: sample from llama_get_logits_ith(ctx, idx)
+//  - seq: sequence id to associate sampler state with (currently only used by mirostat)
 //
 // returns:
 //  - token: sampled token
@@ -205,10 +216,11 @@ llama_token llama_sample_token(
                   struct llama_context * ctx,
                   struct llama_context * ctx_guidance,
                   struct llama_grammar * grammar,
-           const struct gpt_params & params,
+                 struct gpt_params & params,
         const std::vector<llama_token> & last_tokens,
          std::vector<llama_token_data> & candidates,
-                                  int idx = 0);
+                              const int idx = 0,
+                           llama_seq_id seq = 0);
 
 //
 // YAML utils
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 721888da7..8806cf724 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -339,7 +339,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +384,7 @@ int main(int argc, char ** argv) {
 
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
-
+                params.sampler_state.erase(client.seq_id);
                 client.seq_id = -1;
             }
 
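Usage note: after this patch, mirostat mu lives in params.sampler_state keyed
by llama_seq_id rather than in a function-local static, so concurrent
sequences no longer share (and corrupt) one another's state. A minimal sketch
of the intended caller pattern, assuming a context, params, last_tokens, and
candidates set up as in the examples; the loop bound, the 0 logits index, and
the name finished_seq_id are hypothetical, for illustration only:

    // Each sequence id looks up its own llama_sampler_state entry; the entry
    // is created on first use with mirostat_mu = 2.0f * mirostat_tau.
    for (llama_seq_id seq = 0; seq < 4; seq++) {
        const llama_token id = llama_sample_token(
            ctx, NULL, NULL, params, last_tokens, candidates, 0, seq);
        // ... feed `id` back into the batch for sequence `seq` ...
    }

    // When a sequence ends, erase its entry so a reused sequence id starts
    // from fresh mirostat state (see the parallel.cpp hunk above).
    params.sampler_state.erase(finished_seq_id);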