diff --git a/expose.h b/expose.h index b74718eb9..89471f8c2 100644 --- a/expose.h +++ b/expose.h @@ -1,6 +1,18 @@ #pragma once const int stop_token_max = 10; +// match kobold's sampler list and order +enum samplers +{ + KCPP_SAMPLER_TOP_K, + KCPP_SAMPLER_TOP_A, + KCPP_SAMPLER_TOP_P, + KCPP_SAMPLER_TFS, + KCPP_SAMPLER_TYP, + KCPP_SAMPLER_TEMP, + KCPP_SAMPLER_REP_PEN, + KCPP_SAMPLER_MAX +}; struct load_model_inputs { const int threads; @@ -40,6 +52,8 @@ struct generation_inputs const int mirostat = 0; const float mirostat_eta; const float mirostat_tau; + const samplers sampler_order[KCPP_SAMPLER_MAX]; + const int sampler_len; const char * stop_sequence[stop_token_max]; const bool stream_sse; }; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 8d49b67b3..396e53dd5 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -219,8 +219,16 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) candidates->size = last_idx; } +void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p) +{ + auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); + llama_sample_repetition_penalty(nullptr, &candidates_p, + last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + last_n_repeat, rep_pen); +} + int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng, -int mirostat, float mirostat_tau, float mirostat_eta) +int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX]) { int id = 0; std::vector candidates; @@ -231,11 +239,11 @@ int mirostat, float mirostat_tau, float mirostat_eta) llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - // Apply penalties - auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); - llama_sample_repetition_penalty(nullptr, &candidates_p, - last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - last_n_repeat, rep_pen); + // Run this except for when we are going to do the sampler reordering case below + if (temp <= 0 || mirostat > 0 || sampler_len == 0) + { + apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p); + } // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, @@ -261,6 +269,37 @@ int mirostat, float mirostat_tau, float mirostat_eta) llama_sample_temperature(nullptr, &candidates_p, temp); id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu); } + else if (sampler_len > 0) + { + for (int i = 0; i < sampler_len; i++) { + switch (sampler_order[i]) { + case KCPP_SAMPLER_TOP_K: + llama_sample_top_k(nullptr, &candidates_p, top_k,1); + break; + case KCPP_SAMPLER_TOP_A: + sample_top_a(&candidates_p,top_a,1); + break; + case KCPP_SAMPLER_TOP_P: + llama_sample_top_p(nullptr, &candidates_p, top_p,1); + break; + case KCPP_SAMPLER_TFS: + llama_sample_tail_free(nullptr, &candidates_p, tfs,1); + break; + case KCPP_SAMPLER_TYP: + llama_sample_typical(nullptr, &candidates_p, typical_p,1); + break; + case KCPP_SAMPLER_TEMP: + llama_sample_temperature(nullptr, &candidates_p, temp); + break; + case KCPP_SAMPLER_REP_PEN: + apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p); + break; + default: + break; + } + } + id = sample_token(&candidates_p, rng); + } else { // Temperature sampling @@ -1235,7 +1274,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_a, top_p, typical_p, tfs_z, temp, rng, - params.mirostat,params.mirostat_tau,params.mirostat_eta); + params.mirostat, params.mirostat_tau, params.mirostat_eta, + inputs.sampler_len, inputs.sampler_order); last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); diff --git a/koboldcpp.py b/koboldcpp.py index 2041b0d24..284463441 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -9,6 +9,7 @@ import json, sys, http.server, time, asyncio, socket, threading from concurrent.futures import ThreadPoolExecutor stop_token_max = 10 +sampler_order_max = 7 class load_model_inputs(ctypes.Structure): _fields_ = [("threads", ctypes.c_int), @@ -47,6 +48,8 @@ class generation_inputs(ctypes.Structure): ("mirostat", ctypes.c_int), ("mirostat_tau", ctypes.c_float), ("mirostat_eta", ctypes.c_float), + ("sampler_order", ctypes.c_int * sampler_order_max), + ("sampler_len", ctypes.c_int), ("stop_sequence", ctypes.c_char_p * stop_token_max), ("stream_sse", ctypes.c_bool)] @@ -186,7 +189,7 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False): +def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=None, seed=-1, stop_sequence=[], stream_sse=False): inputs = generation_inputs() outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) inputs.prompt = prompt.encode("UTF-8") @@ -205,8 +208,19 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k= inputs.mirostat = int(args.usemirostat[0]) inputs.mirostat_tau = float(args.usemirostat[1]) inputs.mirostat_eta = float(args.usemirostat[2]) + elif mirostat in (1, 2): + inputs.mirostat = mirostat + inputs.mirostat_tau = mirostat_tau + inputs.mirostat_eta = mirostat_eta else: inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0 + if sampler_order and 0 < len(sampler_order) <= sampler_order_max: + try: + for i, sampler in enumerate(sampler_order): + inputs.sampler_order[i] = sampler + inputs.sampler_len = len(sampler_order) + except TypeError as e: + print("ERROR: sampler_order must be a list of integers: " + str(e)) inputs.seed = seed for n in range(stop_token_max): if not stop_sequence or n >= len(stop_sequence): @@ -272,6 +286,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): tfs=genparams.get('tfs', 1.0), rep_pen=genparams.get('rep_pen', 1.1), rep_pen_range=genparams.get('rep_pen_range', 128), + mirostat=genparams.get('mirostat', 0), + mirostat_tau=genparams.get('mirostat_tau', 5.0), + mirostat_eta=genparams.get('mirostat_eta', 0.1), + sampler_order=genparams.get('sampler_order', None), seed=genparams.get('sampler_seed', -1), stop_sequence=genparams.get('stop_sequence', []), stream_sse=stream_flag) @@ -288,6 +306,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): tfs=genparams.get('tfs', 1.0), rep_pen=genparams.get('rep_pen', 1.1), rep_pen_range=genparams.get('rep_pen_range', 128), + mirostat=genparams.get('mirostat', 0), + mirostat_tau=genparams.get('mirostat_tau', 5.0), + mirostat_eta=genparams.get('mirostat_eta', 0.1), + sampler_order=genparams.get('sampler_order', None), seed=genparams.get('sampler_seed', -1), stop_sequence=genparams.get('stop_sequence', []), stream_sse=stream_flag)