refactor to avoid code duplication

This commit is contained in:
Concedo 2023-07-04 18:35:03 +08:00
parent 309534dcd0
commit c6c0afdf18
3 changed files with 91 additions and 79 deletions

View file

@ -4,13 +4,13 @@ const int stop_token_max = 10;
// match kobold's sampler list and order // match kobold's sampler list and order
enum samplers enum samplers
{ {
KCPP_SAMPLER_TOP_K, KCPP_SAMPLER_TOP_K=0,
KCPP_SAMPLER_TOP_A, KCPP_SAMPLER_TOP_A=1,
KCPP_SAMPLER_TOP_P, KCPP_SAMPLER_TOP_P=2,
KCPP_SAMPLER_TFS, KCPP_SAMPLER_TFS=3,
KCPP_SAMPLER_TYP, KCPP_SAMPLER_TYP=4,
KCPP_SAMPLER_TEMP, KCPP_SAMPLER_TEMP=5,
KCPP_SAMPLER_REP_PEN, KCPP_SAMPLER_REP_PEN=6,
KCPP_SAMPLER_MAX KCPP_SAMPLER_MAX
}; };
struct load_model_inputs struct load_model_inputs

View file

@ -219,16 +219,31 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
candidates->size = last_idx; candidates->size = last_idx;
} }
void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p) void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array * candidates_p)
{ {
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
llama_sample_repetition_penalty(nullptr, &candidates_p, llama_sample_repetition_penalty(nullptr, candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, rep_pen); last_n_repeat, rep_pen);
} }
void sample_temperature(llama_token_data_array * candidates_p, float temp)
{
if (temp <= 0)
{
// Imitate greedy sampling
temp = 0.01f; //cannot be zero else div0
llama_sample_temperature(nullptr, candidates_p, temp);
llama_sample_top_k(nullptr, candidates_p, 1, 1); //only want first candidate
}
else
{
llama_sample_temperature(nullptr, candidates_p, temp);
}
}
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng, int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX]) int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order)
{ {
int id = 0; int id = 0;
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
@ -239,40 +254,27 @@ int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const sa
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// Run this except for when we are going to do the sampler reordering case below if (mirostat == 1 || mirostat == 2)
if (temp <= 0 || mirostat > 0 || sampler_len == 0)
{
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
}
// llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
if (temp <= 0)
{
// Greedy sampling
id = llama_sample_token_greedy(nullptr, &candidates_p);
}
else
{
if (mirostat == 1)
{ {
static float mirostat_mu = 2.0f * mirostat_tau; static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100; const int mirostat_m = 100;
llama_sample_temperature(nullptr, &candidates_p, temp); sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
sample_temperature(&candidates_p, temp);
if (mirostat == 1)
{
id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} }
else if (mirostat == 2) else
{ {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(nullptr, &candidates_p, temp);
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu); id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
} }
else if (sampler_len > 0) }
else
{
for (int i = 0; i < sampler_order.size(); i++)
{
switch (sampler_order[i])
{ {
for (int i = 0; i < sampler_len; i++) {
switch (sampler_order[i]) {
case KCPP_SAMPLER_TOP_K: case KCPP_SAMPLER_TOP_K:
llama_sample_top_k(nullptr, &candidates_p, top_k,1); llama_sample_top_k(nullptr, &candidates_p, top_k,1);
break; break;
@ -289,29 +291,18 @@ int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const sa
llama_sample_typical(nullptr, &candidates_p, typical_p,1); llama_sample_typical(nullptr, &candidates_p, typical_p,1);
break; break;
case KCPP_SAMPLER_TEMP: case KCPP_SAMPLER_TEMP:
llama_sample_temperature(nullptr, &candidates_p, temp); sample_temperature(&candidates_p, temp);
break; break;
case KCPP_SAMPLER_REP_PEN: case KCPP_SAMPLER_REP_PEN:
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p); sample_rep_pen(n_ctx, rep_pen_range, rep_pen, &candidates_p);
break; break;
default: default:
printf("\nSampleLogits: Unknown Sampler : %d",sampler_order[i]);
break; break;
} }
} }
id = sample_token(&candidates_p, rng); id = sample_token(&candidates_p, rng);
} }
else
{
// Temperature sampling
llama_sample_top_k(nullptr, &candidates_p, top_k,1);
sample_top_a(&candidates_p,top_a,1);
llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
llama_sample_top_p(nullptr, &candidates_p, top_p,1);
llama_sample_temperature(nullptr, &candidates_p, temp);
id = sample_token(&candidates_p, rng);
}
}
return id; return id;
} }
@ -952,6 +943,28 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
std::mt19937 rng(params.seed); std::mt19937 rng(params.seed);
concat_output = ""; concat_output = "";
//prepare sampler order
std::vector<samplers> sampler_order;
if(inputs.sampler_len<=0) //list by value
{
sampler_order = {
KCPP_SAMPLER_REP_PEN,
KCPP_SAMPLER_TOP_K,
KCPP_SAMPLER_TOP_A,
KCPP_SAMPLER_TFS,
KCPP_SAMPLER_TYP,
KCPP_SAMPLER_TOP_P,
KCPP_SAMPLER_TEMP
};
}
else
{
for(int i=0;i<inputs.sampler_len;++i)
{
sampler_order.push_back(inputs.sampler_order[i]);
}
}
bool startedsampling = false; bool startedsampling = false;
bool use_scratch = true; //for normal inference always use scratch bool use_scratch = true; //for normal inference always use scratch
@ -1274,8 +1287,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
top_k, top_a, top_p, typical_p, tfs_z, temp, rng, top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
params.mirostat, params.mirostat_tau, params.mirostat_eta, params.mirostat, params.mirostat_tau, params.mirostat_eta, sampler_order);
inputs.sampler_len, inputs.sampler_order);
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id); last_n_tokens.push_back(id);

View file

@ -189,7 +189,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs) ret = handle.load_model(inputs)
return ret return ret
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=None, seed=-1, stop_sequence=[], stream_sse=False): def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], stream_sse=False):
inputs = generation_inputs() inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8") inputs.prompt = prompt.encode("UTF-8")
@ -289,7 +289,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
mirostat=genparams.get('mirostat', 0), mirostat=genparams.get('mirostat', 0),
mirostat_tau=genparams.get('mirostat_tau', 5.0), mirostat_tau=genparams.get('mirostat_tau', 5.0),
mirostat_eta=genparams.get('mirostat_eta', 0.1), mirostat_eta=genparams.get('mirostat_eta', 0.1),
sampler_order=genparams.get('sampler_order', None), sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
seed=genparams.get('sampler_seed', -1), seed=genparams.get('sampler_seed', -1),
stop_sequence=genparams.get('stop_sequence', []), stop_sequence=genparams.get('stop_sequence', []),
stream_sse=stream_flag) stream_sse=stream_flag)
@ -309,7 +309,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
mirostat=genparams.get('mirostat', 0), mirostat=genparams.get('mirostat', 0),
mirostat_tau=genparams.get('mirostat_tau', 5.0), mirostat_tau=genparams.get('mirostat_tau', 5.0),
mirostat_eta=genparams.get('mirostat_eta', 0.1), mirostat_eta=genparams.get('mirostat_eta', 0.1),
sampler_order=genparams.get('sampler_order', None), sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]),
seed=genparams.get('sampler_seed', -1), seed=genparams.get('sampler_seed', -1),
stop_sequence=genparams.get('stop_sequence', []), stop_sequence=genparams.get('stop_sequence', []),
stream_sse=stream_flag) stream_sse=stream_flag)