added trim_stop flag

This commit is contained in:
Concedo 2023-11-09 16:55:44 +08:00
parent afa466807d
commit 7ef4ec3b16

View file

@ -298,7 +298,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs) ret = handle.load_model(inputs)
return ret return ret
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''): def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False):
global maxctx, args, currentusergenkey, totalgens global maxctx, args, currentusergenkey, totalgens
inputs = generation_inputs() inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@ -351,9 +351,15 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
currentusergenkey = genkey currentusergenkey = genkey
totalgens += 1 totalgens += 1
ret = handle.generate(inputs,outputs) ret = handle.generate(inputs,outputs)
if(ret.status==1): outstr = ""
return ret.text.decode("UTF-8","ignore") if ret.status==1:
return "" outstr = ret.text.decode("UTF-8","ignore")
if trimstop:
for trim_str in stop_sequence:
sindex = outstr.find(trim_str)
if sindex != -1 and trim_str!="":
outstr = outstr[:sindex]
return outstr
def utfprint(str): def utfprint(str):
try: try:
@ -498,7 +504,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
stream_sse=stream_flag, stream_sse=stream_flag,
grammar=genparams.get('grammar', ''), grammar=genparams.get('grammar', ''),
grammar_retain_state = genparams.get('grammar_retain_state', False), grammar_retain_state = genparams.get('grammar_retain_state', False),
genkey=genparams.get('genkey', '')) genkey=genparams.get('genkey', ''),
trimstop=genparams.get('trim_stop', False))
recvtxt = "" recvtxt = ""
if stream_flag: if stream_flag: