add presence penalty
This commit is contained in:
parent
da2db0302c
commit
3f863eed72
3 changed files with 18 additions and 9 deletions
1
expose.h
1
expose.h
|
@ -65,6 +65,7 @@ struct generation_inputs
|
||||||
const float tfs;
|
const float tfs;
|
||||||
const float rep_pen;
|
const float rep_pen;
|
||||||
const int rep_pen_range;
|
const int rep_pen_range;
|
||||||
|
const float presence_penalty = 0.0f;
|
||||||
const int mirostat = 0;
|
const int mirostat = 0;
|
||||||
const float mirostat_eta;
|
const float mirostat_eta;
|
||||||
const float mirostat_tau;
|
const float mirostat_tau;
|
||||||
|
|
|
@ -386,7 +386,7 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
|
||||||
candidates->size = last_idx;
|
candidates->size = last_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p)
|
void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float presence_penalty, llama_token_data_array * candidates_p)
|
||||||
{
|
{
|
||||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
|
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
|
||||||
|
|
||||||
|
@ -414,6 +414,8 @@ void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_dat
|
||||||
} else {
|
} else {
|
||||||
candidates->data[i].logit /= penalty;
|
candidates->data[i].logit /= penalty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
candidates->data[i].logit -= presence_penalty;
|
||||||
}
|
}
|
||||||
|
|
||||||
candidates->sorted = false;
|
candidates->sorted = false;
|
||||||
|
@ -474,7 +476,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
|
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
|
||||||
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
|
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
|
||||||
{
|
{
|
||||||
int id = 0;
|
int id = 0;
|
||||||
|
@ -494,7 +496,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
|
||||||
{
|
{
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
const int mirostat_m = 100;
|
const int mirostat_m = 100;
|
||||||
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
|
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
|
||||||
sample_temperature(&candidates_p, temp);
|
sample_temperature(&candidates_p, temp);
|
||||||
if (mirostat == 1)
|
if (mirostat == 1)
|
||||||
{
|
{
|
||||||
|
@ -531,7 +533,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
|
||||||
sample_temperature(&candidates_p, temp);
|
sample_temperature(&candidates_p, temp);
|
||||||
break;
|
break;
|
||||||
case KCPP_SAMPLER_REP_PEN:
|
case KCPP_SAMPLER_REP_PEN:
|
||||||
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
|
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
|
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
|
||||||
|
@ -1442,6 +1444,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
params.temp = inputs.temperature;
|
params.temp = inputs.temperature;
|
||||||
params.repeat_last_n = inputs.rep_pen_range;
|
params.repeat_last_n = inputs.rep_pen_range;
|
||||||
params.repeat_penalty = inputs.rep_pen;
|
params.repeat_penalty = inputs.rep_pen;
|
||||||
|
params.presence_penalty = inputs.presence_penalty;
|
||||||
params.mirostat = inputs.mirostat;
|
params.mirostat = inputs.mirostat;
|
||||||
params.mirostat_eta = inputs.mirostat_eta;
|
params.mirostat_eta = inputs.mirostat_eta;
|
||||||
params.mirostat_tau = inputs.mirostat_tau;
|
params.mirostat_tau = inputs.mirostat_tau;
|
||||||
|
@ -1836,6 +1839,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const float top_a = inputs.top_a;
|
const float top_a = inputs.top_a;
|
||||||
const float repeat_penalty = params.repeat_penalty;
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
|
const float presence_penalty = params.presence_penalty;
|
||||||
const float typical_p = params.typical_p;
|
const float typical_p = params.typical_p;
|
||||||
const float tfs_z = params.tfs_z;
|
const float tfs_z = params.tfs_z;
|
||||||
|
|
||||||
|
@ -1891,7 +1895,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
|
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, presence_penalty,
|
||||||
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
|
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
|
||||||
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar);
|
params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order, grammar);
|
||||||
|
|
||||||
|
|
12
koboldcpp.py
12
koboldcpp.py
|
@ -60,6 +60,7 @@ class generation_inputs(ctypes.Structure):
|
||||||
("tfs", ctypes.c_float),
|
("tfs", ctypes.c_float),
|
||||||
("rep_pen", ctypes.c_float),
|
("rep_pen", ctypes.c_float),
|
||||||
("rep_pen_range", ctypes.c_int),
|
("rep_pen_range", ctypes.c_int),
|
||||||
|
("presence_penalty", ctypes.c_float),
|
||||||
("mirostat", ctypes.c_int),
|
("mirostat", ctypes.c_int),
|
||||||
("mirostat_tau", ctypes.c_float),
|
("mirostat_tau", ctypes.c_float),
|
||||||
("mirostat_eta", ctypes.c_float),
|
("mirostat_eta", ctypes.c_float),
|
||||||
|
@ -302,7 +303,7 @@ def load_model(model_filename):
|
||||||
ret = handle.load_model(inputs)
|
ret = handle.load_model(inputs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
|
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
|
||||||
global maxctx, args, currentusergenkey, totalgens
|
global maxctx, args, currentusergenkey, totalgens
|
||||||
inputs = generation_inputs()
|
inputs = generation_inputs()
|
||||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||||
|
@ -327,6 +328,7 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
|
||||||
inputs.tfs = tfs
|
inputs.tfs = tfs
|
||||||
inputs.rep_pen = rep_pen
|
inputs.rep_pen = rep_pen
|
||||||
inputs.rep_pen_range = rep_pen_range
|
inputs.rep_pen_range = rep_pen_range
|
||||||
|
inputs.presence_penalty = presence_penalty
|
||||||
inputs.stream_sse = stream_sse
|
inputs.stream_sse = stream_sse
|
||||||
inputs.quiet = quiet
|
inputs.quiet = quiet
|
||||||
inputs.grammar = grammar.encode("UTF-8")
|
inputs.grammar = grammar.encode("UTF-8")
|
||||||
|
@ -440,10 +442,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
genparams["max_length"] = genparams.get('max', 100)
|
genparams["max_length"] = genparams.get('max', 100)
|
||||||
|
|
||||||
elif api_format==3 or api_format==4:
|
elif api_format==3 or api_format==4:
|
||||||
frqp = genparams.get('frequency_penalty', 0.1)
|
|
||||||
scaled_rep_pen = genparams.get('presence_penalty', frqp) + 1
|
|
||||||
genparams["max_length"] = genparams.get('max_tokens', 100)
|
genparams["max_length"] = genparams.get('max_tokens', 100)
|
||||||
genparams["rep_pen"] = scaled_rep_pen
|
presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
|
||||||
|
genparams["presence_penalty"] = presence_penalty
|
||||||
|
if presence_penalty > 0:
|
||||||
|
genparams["rep_pen"] = 1.0
|
||||||
# openai allows either a string or a list as a stop sequence
|
# openai allows either a string or a list as a stop sequence
|
||||||
if isinstance(genparams.get('stop',[]), list):
|
if isinstance(genparams.get('stop',[]), list):
|
||||||
genparams["stop_sequence"] = genparams.get('stop', [])
|
genparams["stop_sequence"] = genparams.get('stop', [])
|
||||||
|
@ -500,6 +503,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
tfs=genparams.get('tfs', 1.0),
|
tfs=genparams.get('tfs', 1.0),
|
||||||
rep_pen=genparams.get('rep_pen', 1.1),
|
rep_pen=genparams.get('rep_pen', 1.1),
|
||||||
rep_pen_range=genparams.get('rep_pen_range', 256),
|
rep_pen_range=genparams.get('rep_pen_range', 256),
|
||||||
|
presence_penalty=genparams.get('presence_penalty', 0.0),
|
||||||
mirostat=genparams.get('mirostat', 0),
|
mirostat=genparams.get('mirostat', 0),
|
||||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue