diff --git a/expose.h b/expose.h
index 9f686f75c..3e17778d7 100644
--- a/expose.h
+++ b/expose.h
@@ -76,6 +76,7 @@ struct generation_inputs
     const bool stream_sse;
     const char * grammar;
     const bool grammar_retain_state;
+    const bool quiet = false;
 };
 struct generation_outputs
 {
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 67b8fe705..e755043b3 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1442,6 +1442,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_threads_batch = n_blasthreads;
     bool stream_sse = inputs.stream_sse;
 
+    bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1;
+
     generation_finished = false; // Set current generation status
     generated_tokens.clear(); // New Generation, new tokens
@@ -1695,7 +1697,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         printf("\nBanned a total of %zu tokens.\n",banned_token_ids.size());
     }
 
-    if(debugmode!=-1)
+    if(allow_regular_prints)
     {
         printf("\n");
     }
@@ -1716,7 +1718,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         // predict
         unsigned int embdsize = embd.size();
         //print progress
-        if (!startedsampling && debugmode!=-1)
+        if (!startedsampling && allow_regular_prints)
        {
             printf("\rProcessing Prompt%s (%d / %zu tokens)", (blasmode ? " [BLAS]" : ""), input_consumed, embd_inp.size());
         }
@@ -1835,7 +1837,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             params.n_threads = original_threads;
             time1 = timer_check();
             timer_start();
-            if(debugmode!=-1)
+            if(allow_regular_prints)
             {
                 printf("\n");
             }
@@ -1910,7 +1912,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 concat_output_mtx.unlock();
             }
 
-            if (startedsampling && debugmode!=-1)
+            if (startedsampling && allow_regular_prints)
             {
                 printf("\rGenerating (%d / %d tokens)", (params.n_predict - remaining_tokens), params.n_predict);
             }
@@ -1935,7 +1937,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             if(inputs.unban_tokens_rt && id==eosID)
             {
                 stopper_unused_tokens = remaining_tokens;
-                if(debugmode!=-1)
+                if(allow_regular_prints)
                 {
                     printf("\n(EOS token triggered!)");
                 }
@@ -1949,7 +1951,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 {
                     stopper_unused_tokens = remaining_tokens;
                     remaining_tokens = 0;
-                    if(debugmode!=-1)
+                    if(allow_regular_prints)
                     {
                         auto match_clean = matched;
                         replace_all(match_clean, "\n", "\\n");
diff --git a/koboldcpp.py b/koboldcpp.py
index a10eecf1a..90bc68e4d 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -70,7 +70,8 @@ class generation_inputs(ctypes.Structure):
                 ("stop_sequence", ctypes.c_char_p * stop_token_max),
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
-                ("grammar_retain_state", ctypes.c_bool)]
+                ("grammar_retain_state", ctypes.c_bool),
+                ("quiet", ctypes.c_bool)]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -299,7 +300,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False):
+def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
     global maxctx, args, currentusergenkey, totalgens
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@@ -323,6 +324,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
     inputs.stream_sse = stream_sse
+    inputs.quiet = quiet
     inputs.grammar = grammar.encode("UTF-8")
     inputs.grammar_retain_state = grammar_retain_state
     inputs.unban_tokens_rt = not use_default_badwordsids
@@ -425,6 +427,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
     async def generate_text(self, genparams, api_format, stream_flag):
         global friendlymodelname
+        is_quiet = genparams.get('quiet', False)
         def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
             if api_format==1:
                 genparams["prompt"] = genparams.get('text', "")
@@ -503,7 +506,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 grammar=genparams.get('grammar', ''),
                 grammar_retain_state = genparams.get('grammar_retain_state', False),
                 genkey=genparams.get('genkey', ''),
-                trimstop=genparams.get('trim_stop', False))
+                trimstop=genparams.get('trim_stop', False),
+                quiet=is_quiet)
 
         recvtxt = ""
         if stream_flag:
@@ -513,7 +517,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         else:
             recvtxt = run_blocking()
 
-        if args.debugmode!=-1:
+        if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
             utfprint("\nOutput: " + recvtxt)
 
         if api_format==1:
@@ -809,7 +813,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 utfprint("Body Err: " + str(body))
                 return self.send_response(503)
 
-            if args.debugmode!=-1:
+            is_quiet = genparams.get('quiet', False)
+            if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
                 utfprint("\nInput: " + json.dumps(genparams))
 
             if args.foreground: