diff --git a/expose.h b/expose.h
index 6e7c54cb4..2c8cc98a2 100644
--- a/expose.h
+++ b/expose.h
@@ -1,6 +1,7 @@
 #pragma once
 
 const int stop_token_max = 10;
+const int ban_token_max = 10;
 // match kobold's sampler list and order
 enum samplers
 {
@@ -35,6 +36,7 @@ struct load_model_inputs
     const int debugmode = 0;
     const int forceversion = 0;
     const int gpulayers = 0;
+    const char * banned_tokens[ban_token_max];
 };
 struct generation_inputs
 {
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 24686d2f3..b876f8441 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -76,6 +76,8 @@ static size_t mem_per_token = 0;
 static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
+static std::vector<std::string> banned_tokens;
+static std::vector<int> banned_token_ids;
 static std::vector<llama_token_data> top_picks;
 static int remaining_tokens = 0;
 static int stopper_unused_tokens = 0;
@@ -344,6 +346,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     = gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx = gpt2_ctx_v3.hparams.n_ctx = mpt_ctx_v3.hparams.n_ctx = params.n_ctx;
 
+    //handle custom token bans
+    banned_tokens.clear();
+    for(int x=0;x<ban_token_max;++x)
+    {
+        std::string word = inputs.banned_tokens[x];
+        if(word!="")
+        {
+            banned_tokens.push_back(word);
+        }
+    }
+
@@ ... @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
+    if(banned_token_ids.size()==0 && banned_tokens.size()>0)
+    {
+        printf("\n[First Run] Banning %d token sequences...",banned_tokens.size());
+        for(int v=0;v<n_vocab;++v)
+        {
+            std::string word = FileFormatTokenizeID(v,file_format);
+            for(int i=0;i<banned_tokens.size();++i)
+            {
+                if(word.find(banned_tokens[i]) != std::string::npos)
+                {
+                    banned_token_ids.push_back(v);
+                    break;
+                }
+            }
+        }
+    }
@@ ... @@
+            int btsize = banned_token_ids.size();
+            if(btsize>0)
+            {
+                for(int t=0;t<btsize;++t)
+                {
+                    logitsPtr[banned_token_ids[t]] = 0;
+                }
+            }
@@ ... @@
+            if(btsize>0)
+            {
+                int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
+                for (int t = 0; t < btsize; ++t)
+                {
+                    logits[banned_token_ids[t]] = (logits[topid] < 0 ? logits[topid] : 0);
+                }
+            }
         }
 
         id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
diff --git a/koboldcpp.py b/koboldcpp.py
index b523a8599..a685df1cc 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -13,6 +13,7 @@
 from concurrent.futures import ThreadPoolExecutor
 
 stop_token_max = 10
 sampler_order_max = 7
+ban_token_max = 10
 
 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -34,7 +35,8 @@ class load_model_inputs(ctypes.Structure):
                 ("blasbatchsize", ctypes.c_int),
                 ("debugmode", ctypes.c_int),
                 ("forceversion", ctypes.c_int),
-                ("gpulayers", ctypes.c_int)]
+                ("gpulayers", ctypes.c_int),
+                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
 
 class generation_inputs(ctypes.Structure):
     _fields_ = [("seed", ctypes.c_int),
@@ -195,6 +197,13 @@ def load_model(model_filename):
     inputs.cublas_info = 2
     inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
     inputs.debugmode = args.debugmode
+    banned_tokens = args.bantokens
+    print(banned_tokens)
+    for n in range(ban_token_max):
+        if not banned_tokens or n >= len(banned_tokens):
+            inputs.banned_tokens[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
     ret = handle.load_model(inputs)
     return ret
@@ -1297,7 +1306,7 @@
     parser.add_argument("--stream", help="Uses streaming when generating tokens. Only for the Kobold Lite UI.", action='store_true')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents the EOS token from being generated. This flag unbans it.", action='store_true')
-    parser.add_argument("--bantokens", help="You can manually specify a list of token IDs that the AI cannot use.", metavar=('[elements]'), nargs='+')
+    parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
     parser.add_argument("--usemirostat", help="Experimental! Replaces your samplers with mirostat. Takes 3 params = [type(0/1/2), tau(5.0), eta(0.1)].",metavar=('[type]', '[tau]', '[eta]'), type=float, nargs=3)
     parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
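
Usage sketch (not part of the patch; the model filename below is illustrative): with this change, --bantokens takes token substrings rather than token IDs, so an invocation could look like

    python koboldcpp.py mymodel.ggml.bin --bantokens "###" "User:"

At load time the strings are copied into load_model_inputs.banned_tokens; on the first generation call the whole vocabulary is scanned once and every token whose decoded text contains any listed substring is collected into banned_token_ids; at sampling time each banned logit is clamped to the lowest current logit (or to 0 when even the lowest is positive), so the sampler effectively never selects it. Because the match uses find(), banning a short substring also bans every longer token that contains it, which is why the updated help text stresses that ALL instances of the substring are banned.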