diff --git a/expose.h b/expose.h index 0905bc3ec..a6768980d 100644 --- a/expose.h +++ b/expose.h @@ -69,6 +69,7 @@ struct generation_inputs const float mirostat_tau; const samplers sampler_order[KCPP_SAMPLER_MAX]; const int sampler_len; + const bool unban_tokens_rt; const char * stop_sequence[stop_token_max]; const bool stream_sse; }; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 79203e1a4..83ee387d4 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1458,7 +1458,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o } float lowestLogit = LowestLogit(logitsPtr,n_vocab); - if (!unbanTokens) + if (!unbanTokens && !inputs.unban_tokens_rt) { // set the logit of the eos token (2) to -INF to avoid sampling it logitsPtr[eosID] = lowestLogit; @@ -1476,7 +1476,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o { logitsPtr = logits.data(); float lowestLogit = LowestLogit(logits); - if (!unbanTokens) + if (!unbanTokens && !inputs.unban_tokens_rt) { //gpt2 uses negative logits, so we cant zero it // set the logit of the eos token to minimum to avoid sampling it @@ -1580,7 +1580,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o printf("]\n"); } - if(unbanTokens && id==eosID) + if((unbanTokens||inputs.unban_tokens_rt) && id==eosID) { stopper_unused_tokens = remaining_tokens; printf("\n(EOS token triggered!)"); diff --git a/koboldcpp.py b/koboldcpp.py index 6289ffd1b..4c82625c8 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -61,6 +61,7 @@ class generation_inputs(ctypes.Structure): ("mirostat_eta", ctypes.c_float), ("sampler_order", ctypes.c_int * sampler_order_max), ("sampler_len", ctypes.c_int), + ("unban_tokens_rt", ctypes.c_bool), ("stop_sequence", ctypes.c_char_p * stop_token_max), ("stream_sse", ctypes.c_bool)] @@ -249,7 +250,7 @@ def load_model(model_filename): ret = handle.load_model(inputs) return ret -def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], stream_sse=False): +def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordids=True, stream_sse=False): global maxctx, args inputs = generation_inputs() outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) @@ -271,6 +272,7 @@ def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_ inputs.rep_pen = rep_pen inputs.rep_pen_range = rep_pen_range inputs.stream_sse = stream_sse + inputs.unban_tokens_rt = not use_default_badwordids if args.usemirostat and args.usemirostat[0]>0: inputs.mirostat = int(args.usemirostat[0]) inputs.mirostat_tau = float(args.usemirostat[1]) @@ -368,6 +370,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), seed=genparams.get('sampler_seed', -1), stop_sequence=genparams.get('stop_sequence', []), + use_default_badwordids=genparams.get('use_default_badwordids', True), stream_sse=stream_flag) else: @@ -388,6 +391,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), seed=genparams.get('sampler_seed', -1), stop_sequence=genparams.get('stop_sequence', []), + use_default_badwordids=genparams.get('use_default_badwordids', True), stream_sse=stream_flag) recvtxt = "" @@ -505,7 +509,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): response_body = (json.dumps({"values": []}).encode()) elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')): - response_body = (json.dumps({"result":"1.2.2"}).encode()) + response_body = (json.dumps({"result":"1.2.4"}).encode()) elif self.path.endswith(('/api/extra/version')): response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())