From abfdfb702e483c98d7644813c61ff77ee38b49c6 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Sat, 27 May 2023 17:32:37 +0800
Subject: [PATCH] added top_a sampler

---
 expose.h            |  1 +
 gpttype_adapter.cpp | 38 +++++++++++++++++++++++++++++++++++---
 koboldcpp.py        | 10 +++++++---
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/expose.h b/expose.h
index 99ddf892c..cb475f141 100644
--- a/expose.h
+++ b/expose.h
@@ -29,6 +29,7 @@ struct generation_inputs
     const int max_length;
     const float temperature;
     const int top_k;
+    const float top_a = 0.0f;
     const float top_p;
     const float typical_p;
     const float tfs;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 55a4668e4..d82175100 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -175,7 +175,37 @@ llama_token sample_token_mirostat_v2(llama_token_data_array * candidates, std::m
     return X;
 }
 
-int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
+// Top-a: remove all tokens whose softmax probability is less than top_a*m^2, where m is the maximum softmax probability.
+// A top_a value of 0 disables the filter (no effect).
+void sample_top_a(llama_token_data_array * candidates, float a) {
+    if (a <= 0.0f || candidates->size <= 1) {
+        return;
+    }
+
+    llama_sample_softmax(nullptr, candidates);
+
+    // The softmax call leaves the candidates sorted by probability, so the first entry holds the maximum
+    float maxprob = candidates->data[0].p;
+
+    float threshold = a * maxprob * maxprob; // tokens with probs less than this are removed
+    size_t last_idx = candidates->size;
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        // Go until we reach a value under the threshold
+        float checkprob = candidates->data[i].p;
+        if (checkprob < threshold) {
+            last_idx = i;
+            break;
+        }
+    }
+    // printf("\n\nCandidates: %d, A:%f, MaxProb: %f, Threshold: %f, LastIdx: %d",candidates->size,a,maxprob,threshold,last_idx);
+    // printf("\nCandidates: %f %f %f %f\n",candidates->data[0].p,candidates->data[1].p,candidates->data[2].p,candidates->data[3].p);
+
+    // Resize the output vector to keep only the selected tokens
+    candidates->size = last_idx;
+}
+
+int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
 int mirostat, float mirostat_tau, float mirostat_eta)
 {
     int id = 0;
@@ -221,6 +251,7 @@ int mirostat, float mirostat_tau, float mirostat_eta) {
 
             // Temperature sampling
             llama_sample_top_k(nullptr, &candidates_p, top_k,1);
+            sample_top_a(&candidates_p, top_a);
             llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
             llama_sample_typical(nullptr, &candidates_p, typical_p,1);
             llama_sample_top_p(nullptr, &candidates_p, top_p,1);
@@ -659,7 +690,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     }
     if (params.top_k < 1)
     {
-        params.top_k = 200; //to disable top_k we actually need to increase this value to a very high number
+        params.top_k = 300; //to disable top_k we actually need to increase this value to a very high number
     }
     if (params.seed <= 0)
     {
@@ -937,6 +968,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     const float top_k = params.top_k;
     const float top_p = params.top_p;
     const float temp = params.temp;
+    const float top_a = inputs.top_a;
     const float repeat_penalty = params.repeat_penalty;
     const float typical_p = params.typical_p;
     const float tfs_z = params.tfs_z;
@@ -1015,7 +1047,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             }
 
             id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
-            top_k, top_p, typical_p, tfs_z, temp, rng,
+            top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
             params.mirostat,params.mirostat_tau,params.mirostat_eta);
 
             last_n_tokens.erase(last_n_tokens.begin());
diff --git a/koboldcpp.py b/koboldcpp.py
index 3fca437fd..5446477a8 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -35,6 +35,7 @@ class generation_inputs(ctypes.Structure):
                 ("max_length", ctypes.c_int),
                 ("temperature", ctypes.c_float),
                 ("top_k", ctypes.c_int),
+                ("top_a", ctypes.c_float),
                 ("top_p", ctypes.c_float),
                 ("typical_p", ctypes.c_float),
                 ("tfs", ctypes.c_float),
@@ -161,7 +162,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=300, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -169,6 +170,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     inputs.max_length = max_length
     inputs.temperature = temperature
     inputs.top_k = top_k
+    inputs.top_a = top_a
     inputs.top_p = top_p
     inputs.typical_p = typical_p
     inputs.tfs = tfs
@@ -326,7 +328,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     max_context_length=genparams.get('max_context_length', maxctx),
                     max_length=genparams.get('max_length', 50),
                     temperature=genparams.get('temperature', 0.8),
-                    top_k=genparams.get('top_k', 200),
+                    top_k=genparams.get('top_k', 300),
+                    top_a=genparams.get('top_a', 0.0),
                     top_p=genparams.get('top_p', 0.85),
                     typical_p=genparams.get('typical', 1.0),
                     tfs=genparams.get('tfs', 1.0),
@@ -342,7 +345,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     prompt=newprompt,
                     max_length=genparams.get('max', 50),
                     temperature=genparams.get('temperature', 0.8),
-                    top_k=genparams.get('top_k', 200),
+                    top_k=genparams.get('top_k', 300),
+                    top_a=genparams.get('top_a', 0.0),
                     top_p=genparams.get('top_p', 0.85),
                     typical_p=genparams.get('typical', 1.0),
                     tfs=genparams.get('tfs', 1.0),
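
Note for reviewers (not part of the patch): a minimal standalone sketch of the top-a rule in Python, for sanity-checking the threshold math in sample_top_a. The helper name top_a_filter and the toy logits below are illustrative only and do not exist in the codebase; the sketch keeps every token whose softmax probability is at least a * (max probability)^2, which matches the truncation the C++ code performs on the sorted candidate list.

    import math

    def top_a_filter(logits, a):
        """Return the indices kept by top-a: tokens whose softmax probability
        is at least a * (max probability)^2. A value of a <= 0 keeps everything."""
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]   # numerically stable softmax
        total = sum(exps)
        probs = [e / total for e in exps]
        if a <= 0.0 or len(probs) <= 1:
            return list(range(len(probs)))
        threshold = a * max(probs) ** 2
        return [i for i, p in enumerate(probs) if p >= threshold]

    print(top_a_filter([2.0, 2.0, 2.0, 2.0], 0.5))  # flat distribution -> [0, 1, 2, 3], nothing pruned
    print(top_a_filter([8.0, 1.0, 0.5, 0.1], 0.5))  # peaked distribution -> [0], the tail is pruned

With the patch applied, the parameter is exposed end to end: a caller can pass it through the Python wrapper, e.g. generate(prompt, top_k=300, top_a=0.2, top_p=0.85), or supply a "top_a" key in the genparams payload read by ServerRequestHandler. The default of top_a=0.0 disables the filter, so existing clients see unchanged behavior.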