diff --git a/expose.h b/expose.h index 89471f8c2..3d3677b0d 100644 --- a/expose.h +++ b/expose.h @@ -4,13 +4,13 @@ const int stop_token_max = 10; // match kobold's sampler list and order enum samplers { - KCPP_SAMPLER_TOP_K, - KCPP_SAMPLER_TOP_A, - KCPP_SAMPLER_TOP_P, - KCPP_SAMPLER_TFS, - KCPP_SAMPLER_TYP, - KCPP_SAMPLER_TEMP, - KCPP_SAMPLER_REP_PEN, + KCPP_SAMPLER_TOP_K=0, + KCPP_SAMPLER_TOP_A=1, + KCPP_SAMPLER_TOP_P=2, + KCPP_SAMPLER_TFS=3, + KCPP_SAMPLER_TYP=4, + KCPP_SAMPLER_TEMP=5, + KCPP_SAMPLER_REP_PEN=6, KCPP_SAMPLER_MAX }; struct load_model_inputs diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 396e53dd5..cdbbf3b73 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -219,16 +219,31 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) candidates->size = last_idx; } -void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p) +void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p) { auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); - llama_sample_repetition_penalty(nullptr, &candidates_p, + llama_sample_repetition_penalty(nullptr, candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, rep_pen); } +void sample_temperature(llama_token_data_array * candidates_p, float temp) +{ + if (temp <= 0) + { + // Imitate greedy sampling + temp = 0.01f; //cannot be zero else div0 + llama_sample_temperature(nullptr, candidates_p, temp); + llama_sample_top_k(nullptr, candidates_p, 1, 1); //only want first candidate + } + else + { + llama_sample_temperature(nullptr, candidates_p, temp); + } +} + int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng, -int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX]) +int mirostat, float mirostat_tau, float mirostat_eta, const std::vector & sampler_order) { int id = 0; std::vector candidates; @@ -239,79 +254,55 @@ int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const sa llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - // Run this except for when we are going to do the sampler reordering case below - if (temp <= 0 || mirostat > 0 || sampler_len == 0) - { - apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p); - } - - // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, - // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - // last_n_repeat, alpha_frequency, alpha_presence); - - if (temp <= 0) - { - // Greedy sampling - id = llama_sample_token_greedy(nullptr, &candidates_p); - } - else + if (mirostat == 1 || mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; + const int mirostat_m = 100; + sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p); + sample_temperature(&candidates_p, temp); if (mirostat == 1) { - static float mirostat_mu = 2.0f * mirostat_tau; - const int mirostat_m = 100; - llama_sample_temperature(nullptr, &candidates_p, temp); id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } - else if (mirostat == 2) - { - static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(nullptr, &candidates_p, temp); - id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu); - } - else if (sampler_len > 0) - { - for (int i = 0; i < sampler_len; i++) { - switch (sampler_order[i]) { - case KCPP_SAMPLER_TOP_K: - llama_sample_top_k(nullptr, &candidates_p, top_k,1); - break; - case KCPP_SAMPLER_TOP_A: - sample_top_a(&candidates_p,top_a,1); - break; - case KCPP_SAMPLER_TOP_P: - llama_sample_top_p(nullptr, &candidates_p, top_p,1); - break; - case KCPP_SAMPLER_TFS: - llama_sample_tail_free(nullptr, &candidates_p, tfs,1); - break; - case KCPP_SAMPLER_TYP: - llama_sample_typical(nullptr, &candidates_p, typical_p,1); - break; - case KCPP_SAMPLER_TEMP: - llama_sample_temperature(nullptr, &candidates_p, temp); - break; - case KCPP_SAMPLER_REP_PEN: - apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p); - break; - default: - break; - } - } - id = sample_token(&candidates_p, rng); - } else { - // Temperature sampling - llama_sample_top_k(nullptr, &candidates_p, top_k,1); - sample_top_a(&candidates_p,top_a,1); - llama_sample_tail_free(nullptr, &candidates_p, tfs,1); - llama_sample_typical(nullptr, &candidates_p, typical_p,1); - llama_sample_top_p(nullptr, &candidates_p, top_p,1); - llama_sample_temperature(nullptr, &candidates_p, temp); - id = sample_token(&candidates_p, rng); + id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu); } } + else + { + for (int i = 0; i < sampler_order.size(); i++) + { + switch (sampler_order[i]) + { + case KCPP_SAMPLER_TOP_K: + llama_sample_top_k(nullptr, &candidates_p, top_k,1); + break; + case KCPP_SAMPLER_TOP_A: + sample_top_a(&candidates_p,top_a,1); + break; + case KCPP_SAMPLER_TOP_P: + llama_sample_top_p(nullptr, &candidates_p, top_p,1); + break; + case KCPP_SAMPLER_TFS: + llama_sample_tail_free(nullptr, &candidates_p, tfs,1); + break; + case KCPP_SAMPLER_TYP: + llama_sample_typical(nullptr, &candidates_p, typical_p,1); + break; + case KCPP_SAMPLER_TEMP: + sample_temperature(&candidates_p, temp); + break; + case KCPP_SAMPLER_REP_PEN: + sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p); + break; + default: + printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]); + break; + } + } + id = sample_token(&candidates_p, rng); + } return id; } @@ -952,6 +943,28 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o std::mt19937 rng(params.seed); concat_output = ""; + //prepare sampler order + std::vector sampler_order; + if(inputs.sampler_len<=0) //list by value + { + sampler_order = { + KCPP_SAMPLER_REP_PEN, + KCPP_SAMPLER_TOP_K, + KCPP_SAMPLER_TOP_A, + KCPP_SAMPLER_TFS, + KCPP_SAMPLER_TYP, + KCPP_SAMPLER_TOP_P, + KCPP_SAMPLER_TEMP + }; + } + else + { + for(int i=0;i