Do not default to Repetition Penalty 1.1 (#615)
* Do not default to Repetition Penalty. * Apply all known aliases for repetition penalty when using the OAI endpoint; rep pen now defaults to 1, and rep pen range to 256. --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent
b9ad08af19
commit
bd77a48037
1 changed file with 13 additions and 5 deletions
18
koboldcpp.py
18
koboldcpp.py
|
@ -311,7 +311,7 @@ def load_model(model_filename):
|
||||||
ret = handle.load_model(inputs)
|
ret = handle.load_model(inputs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, logit_biases={}):
|
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.0, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, dynatemp_range=0.0, logit_biases={}):
|
||||||
global maxctx, args, currentusergenkey, totalgens
|
global maxctx, args, currentusergenkey, totalgens
|
||||||
inputs = generation_inputs()
|
inputs = generation_inputs()
|
||||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||||
|
@ -468,6 +468,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
global friendlymodelname
|
global friendlymodelname
|
||||||
is_quiet = args.quiet
|
is_quiet = args.quiet
|
||||||
def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
|
def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
|
||||||
|
|
||||||
|
#alias all nonstandard alternative names for rep pen.
|
||||||
|
rp1 = genparams.get('repeat_penalty', 1.0)
|
||||||
|
rp2 = genparams.get('repetition_penalty', 1.0)
|
||||||
|
rp3 = genparams.get('rep_pen', 1.0)
|
||||||
|
rp_max = max(rp1,rp2,rp3)
|
||||||
|
genparams["rep_pen"] = rp_max
|
||||||
|
|
||||||
if api_format==1:
|
if api_format==1:
|
||||||
genparams["prompt"] = genparams.get('text', "")
|
genparams["prompt"] = genparams.get('text', "")
|
||||||
genparams["top_k"] = int(genparams.get('top_k', 120))
|
genparams["top_k"] = int(genparams.get('top_k', 120))
|
||||||
|
@ -477,8 +485,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
genparams["max_length"] = genparams.get('max_tokens', 100)
|
genparams["max_length"] = genparams.get('max_tokens', 100)
|
||||||
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
|
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
|
||||||
genparams["presence_penalty"] = presence_penalty
|
genparams["presence_penalty"] = presence_penalty
|
||||||
if presence_penalty > 0 and (genparams.get('rep_pen', 0)==0):
|
if presence_penalty > 0:
|
||||||
genparams["rep_pen"] = 1.0
|
genparams["rep_pen"] = 1.0 #disable rep pen if presence pen is specified for OAI
|
||||||
# openai allows either a string or a list as a stop sequence
|
# openai allows either a string or a list as a stop sequence
|
||||||
if isinstance(genparams.get('stop',[]), list):
|
if isinstance(genparams.get('stop',[]), list):
|
||||||
genparams["stop_sequence"] = genparams.get('stop', [])
|
genparams["stop_sequence"] = genparams.get('stop', [])
|
||||||
|
@ -533,7 +541,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
min_p=genparams.get('min_p', 0.0),
|
min_p=genparams.get('min_p', 0.0),
|
||||||
typical_p=genparams.get('typical', 1.0),
|
typical_p=genparams.get('typical', 1.0),
|
||||||
tfs=genparams.get('tfs', 1.0),
|
tfs=genparams.get('tfs', 1.0),
|
||||||
rep_pen=genparams.get('rep_pen', 1.1),
|
rep_pen=genparams.get('rep_pen', 1.0),
|
||||||
rep_pen_range=genparams.get('rep_pen_range', 256),
|
rep_pen_range=genparams.get('rep_pen_range', 256),
|
||||||
presence_penalty=genparams.get('presence_penalty', 0.0),
|
presence_penalty=genparams.get('presence_penalty', 0.0),
|
||||||
mirostat=genparams.get('mirostat', 0),
|
mirostat=genparams.get('mirostat', 0),
|
||||||
|
@ -687,7 +695,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
temperature = float(parsed_dict['temperature'][0]) if 'temperature' in parsed_dict else 0.7
|
temperature = float(parsed_dict['temperature'][0]) if 'temperature' in parsed_dict else 0.7
|
||||||
top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100
|
top_k = int(parsed_dict['top_k'][0]) if 'top_k' in parsed_dict else 100
|
||||||
top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9
|
top_p = float(parsed_dict['top_p'][0]) if 'top_p' in parsed_dict else 0.9
|
||||||
rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.1
|
rep_pen = float(parsed_dict['rep_pen'][0]) if 'rep_pen' in parsed_dict else 1.0
|
||||||
use_default_badwordsids = int(parsed_dict['use_default_badwordsids'][0]) if 'use_default_badwordsids' in parsed_dict else 0
|
use_default_badwordsids = int(parsed_dict['use_default_badwordsids'][0]) if 'use_default_badwordsids' in parsed_dict else 0
|
||||||
gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate"
|
gencommand = (parsed_dict['generate'][0] if 'generate' in parsed_dict else "")=="Generate"
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue