From 3962eb39c74d301fb9845d9d0a0d436bb5bd771e Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Mon, 24 Apr 2023 21:50:20 +0800
Subject: [PATCH] added token unbanning

---
 expose.h            |  1 +
 gpttype_adapter.cpp | 42 +++++++++++++++++++++++++++---------------
 koboldcpp.py        |  5 ++++-
 3 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/expose.h b/expose.h
index 34f184924..b5545b4d0 100644
--- a/expose.h
+++ b/expose.h
@@ -12,6 +12,7 @@ struct load_model_inputs
     const char * lora_filename;
     const bool use_mmap;
     const bool use_smartcontext;
+    const bool unban_tokens;
     const int clblast_info = 0;
     const int blasbatchsize = 512;
 };
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 039237aa1..df185405f 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -41,6 +41,7 @@ static int n_past = 0;
 static int n_threads = 4;
 static int n_batch = 8;
 static bool useSmartContext = false;
+static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static std::string modelname;
 static std::vector last_n_tokens;
@@ -65,6 +66,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
+    unbanTokens = inputs.unban_tokens;
     blasbatchsize = inputs.blasbatchsize;
     params.memory_f16 = inputs.f16_kv;
     params.n_ctx = inputs.max_context_length;
@@ -366,7 +368,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         params.n_batch = bbs; //received reports of 1024 and above crashing on some models
-        //params.n_threads = 1; //do not limit here anymore.
+        if(!ggml_cpu_has_cublas())
+        {
+            params.n_threads = 1; //do not limit here anymore.
+        }
     }
 
     current_context_tokens.resize(n_past);
@@ -512,28 +517,35 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
             {
                 auto logits = llama_get_logits(llama_ctx_v1);
-                // set the logit of the eos token (2) to zero to avoid sampling it
-                logits[llama_token_eos()] = 0;
-                //set logits of opening square bracket to zero.
-                logits[518] = 0;
-                logits[29961] = 0;
+
+                if (!unbanTokens)
+                {
+                    // set the logit of the eos token (2) to zero to avoid sampling it
+                    logits[llama_token_eos()] = 0;
+                    //set logits of opening square bracket to zero.
+                    logits[518] = 0;
+                    logits[29961] = 0;
+                }
 
                 id = llama_sample_top_p_top_k(llama_ctx_v1, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty);
             }
             else
             {
-                // set the logit of the eos token (2) to zero to avoid sampling it
-                if((file_format == FileFormat::GPT2_1 ||
-                file_format == FileFormat::GPT2_2 ||
-                file_format == FileFormat::GPTJ_1 ||
-                file_format == FileFormat::GPTJ_2 ||
-                file_format == FileFormat::GPTJ_3)
-                && logits.size()>50256)
+                if (!unbanTokens)
                 {
-                    logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
+                    // set the logit of the eos token (2) to zero to avoid sampling it
+                    if ((file_format == FileFormat::GPT2_1 ||
+                         file_format == FileFormat::GPT2_2 ||
+                         file_format == FileFormat::GPTJ_1 ||
+                         file_format == FileFormat::GPTJ_2 ||
+                         file_format == FileFormat::GPTJ_3) &&
+                        logits.size() > 50256)
+                    {
+                        logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
+                    }
+                    //gpt2 uses negative logits, so we cant zero it
                 }
-                //gpt2 uses negative logits, so we cant zero it
 
                 id = gptj_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
             }
diff --git a/koboldcpp.py b/koboldcpp.py
index 6c2e4e6d3..2a0245135 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
                 ("lora_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
+                ("unban_tokens", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int),
                 ("blasbatchsize", ctypes.c_int)]
 
@@ -96,9 +97,10 @@ def load_model(model_filename):
     inputs.batch_size = 8
     inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
     inputs.threads = args.threads
-    inputs.f16_kv = True
+    inputs.f16_kv = True
     inputs.use_mmap = (not args.nommap)
     inputs.use_smartcontext = args.smartcontext
+    inputs.unban_tokens = args.unbantokens
     inputs.blasbatchsize = args.blasbatchsize
     clblastids = 0
     if args.useclblast:
@@ -497,6 +499,7 @@ if __name__ == '__main__':
     parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512)", type=int,choices=[32,64,128,256,512,1024], default=512)
     parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
+    parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents certain tokens such as EOS and Square Brackets. This flag unbans them.", action='store_true')
    parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
     compatgroup = parser.add_mutually_exclusive_group()
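
The patch depends on one layout rule: the new unban_tokens member sits at the same position in the C load_model_inputs struct (expose.h) and in the ctypes _fields_ list (koboldcpp.py); if the two orders ever diverge, every member after the mismatch is read at the wrong byte offset across the FFI boundary. Below is a minimal sketch of that rule, using a hypothetical trimmed-down struct (demo_inputs) rather than the real load_model_inputs.

    import ctypes

    # Hypothetical miniature of the load_model_inputs pattern: the _fields_
    # order and types must mirror the C struct member-for-member, otherwise
    # members after the first mismatch land at the wrong offsets.
    class demo_inputs(ctypes.Structure):
        _fields_ = [("use_smartcontext", ctypes.c_bool),
                    ("unban_tokens", ctypes.c_bool),  # new flag, same position as in expose.h
                    ("blasbatchsize", ctypes.c_int)]

    inputs = demo_inputs()
    inputs.unban_tokens = True  # in the real script this would come from args.unbantokens
    print(inputs.unban_tokens, ctypes.sizeof(demo_inputs))

With the patch applied, the new behaviour would presumably be enabled from the command line with something like `python koboldcpp.py <model_file> --unbantokens`.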