fixed stop sequence crash

This commit is contained in:
Concedo 2023-05-02 14:56:50 +08:00
parent 94827172e0
commit 6f702f2700

View file

@ -162,12 +162,11 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
inputs.rep_pen = rep_pen inputs.rep_pen = rep_pen
inputs.rep_pen_range = rep_pen_range inputs.rep_pen_range = rep_pen_range
inputs.seed = seed inputs.seed = seed
if stop_sequence: for n in range(0,stop_token_max):
for n in range(0,stop_token_max): if not stop_sequence or n >= len(stop_sequence):
if n >= len(stop_sequence): inputs.stop_sequence[n] = "".encode("UTF-8")
inputs.stop_sequence[n] = "".encode("UTF-8") else:
else: inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
ret = handle.generate(inputs,outputs) ret = handle.generate(inputs,outputs)
if(ret.status==1): if(ret.status==1):
return ret.text.decode("UTF-8","ignore") return ret.text.decode("UTF-8","ignore")