diff --git a/koboldcpp.py b/koboldcpp.py index 1e5130a17..f942ec066 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -161,13 +161,12 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k= inputs.tfs = tfs inputs.rep_pen = rep_pen inputs.rep_pen_range = rep_pen_range - inputs.seed = seed - if stop_sequence: - for n in range(0,stop_token_max): - if n >= len(stop_sequence): - inputs.stop_sequence[n] = "".encode("UTF-8") - else: - inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8") + inputs.seed = seed + for n in range(0,stop_token_max): + if not stop_sequence or n >= len(stop_sequence): + inputs.stop_sequence[n] = "".encode("UTF-8") + else: + inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8") ret = handle.generate(inputs,outputs) if(ret.status==1): return ret.text.decode("UTF-8","ignore")