integrated the new samplers
commit 7afad2b9b5
parent da0c34b028
3 changed files with 71 additions and 4 deletions
expose.h (2 changes)
@@ -26,6 +26,8 @@ struct generation_inputs
     const float temperature;
     const int top_k;
     const float top_p;
+    const float typical_p;
+    const float tfs;
     const float rep_pen;
     const int rep_pen_range;
     const char * stop_sequence[stop_token_max];
gpttype_adapter.cpp (63 changes)
@@ -78,6 +78,59 @@ inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr
     return true;
 }
 
+llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
+{
+    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sample_softmax(nullptr, candidates);
+    std::vector<float> probs;
+    probs.reserve(candidates->size);
+    for (size_t i = 0; i < candidates->size; ++i) {
+        probs.push_back(candidates->data[i].p);
+    }
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+    llama_token result = candidates->data[idx].id;
+    return result;
+}
+
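For context on the new helper: std::discrete_distribution normalizes the weights it is given and samples an index in proportion to them, so after llama_sample_softmax the draw is a straightforward weighted pick. A minimal, self-contained sketch of that core step, with toy logits and a made-up seed instead of the llama.cpp types:

// Toy reproduction of the softmax + weighted-draw step in sample_token.
// The logit values and the seed 1337 are illustrative, not from the commit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

int main() {
    std::vector<float> logits = {2.0f, 1.0f, 0.1f};
    // Softmax with max-subtraction for numerical stability.
    float max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs;
    float sum = 0.0f;
    for (float l : logits) { probs.push_back(std::exp(l - max_l)); sum += probs.back(); }
    for (float & p : probs) { p /= sum; }
    // A seeded engine makes the draw reproducible; the real code threads the
    // caller's std::mt19937 through SampleLogits into sample_token.
    std::mt19937 rng(1337);
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);
    std::printf("sampled index %d (p = %.3f)\n", idx, probs[idx]);
    return 0;
}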
+int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng)
+{
+    int id = 0;
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+    // Apply penalties
+    auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
+    llama_sample_repetition_penalty(nullptr, &candidates_p,
+        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        last_n_repeat, rep_pen);
+
+    // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
+    //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+    //     last_n_repeat, alpha_frequency, alpha_presence);
+
+    if (temp <= 0) {
+        // Greedy sampling
+        id = llama_sample_token_greedy(nullptr, &candidates_p);
+    } else {
+        // Temperature sampling
+        llama_sample_top_k(nullptr, &candidates_p, top_k);
+        llama_sample_tail_free(nullptr, &candidates_p, tfs);
+        llama_sample_typical(nullptr, &candidates_p, typical_p);
+        llama_sample_top_p(nullptr, &candidates_p, top_p);
+        llama_sample_temperature(nullptr, &candidates_p, temp);
+        id = sample_token(&candidates_p, rng);
+    }
+
+    return id;
+}
+
 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
     ggml_time_init();
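Worth noting about SampleLogits: the stages run in a fixed order. The repetition penalty is applied to the full candidate set first; then, unless temp <= 0 forces a greedy argmax, top-k prunes the pool before the probabilistic truncations (tail-free, typical, top-p), and temperature reshapes whatever survives just before the final draw. A self-contained sketch of that shape using plain std:: containers, with only top-k and temperature shown (the other filters slot in between the same way; all values here are toy data, not from the commit):

// Illustrative sketch of the filter chain in SampleLogits, without the
// llama_token_data_array machinery. Each stage prunes or reweights the
// surviving candidates before the final stochastic draw.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<int, float>> cand = { // {token id, logit}, toy values
        {0, 1.5f}, {1, 0.2f}, {2, 2.1f}, {3, -0.7f}, {4, 1.9f}};
    const int top_k = 3;
    const float temp = 0.7f;

    // top-k: keep only the k highest-logit candidates.
    std::sort(cand.begin(), cand.end(),
              [](const auto & a, const auto & b) { return a.second > b.second; });
    cand.resize(std::min<size_t>(top_k, cand.size()));

    // temperature: divide logits before the softmax; temp < 1 sharpens.
    std::vector<float> probs;
    float sum = 0.0f;
    for (const auto & c : cand) { probs.push_back(std::exp(c.second / temp)); sum += probs.back(); }
    for (float & p : probs) { p /= sum; }

    // final weighted draw, as in sample_token.
    std::mt19937 rng(42);
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    std::printf("picked token id %d\n", cand[dist(rng)].first);
    return 0;
}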
@@ -311,6 +364,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_predict = inputs.max_length;
     params.top_k = inputs.top_k;
     params.top_p = inputs.top_p;
+    params.typical_p = inputs.typical_p;
+    params.tfs_z = inputs.tfs;
     params.temp = inputs.temperature;
     params.repeat_last_n = inputs.rep_pen_range;
     params.repeat_penalty = inputs.rep_pen;
@@ -423,7 +478,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
     if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
     {
-        //do nothing
+        n_vocab = llama_n_vocab(llama_ctx_v1);
     }
     else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2)
     {
@@ -557,6 +612,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     const float top_p = params.top_p;
     const float temp = params.temp;
     const float repeat_penalty = params.repeat_penalty;
+    const float typical_p = params.typical_p;
+    const float tfs_z = params.tfs_z;
 
     if (!startedsampling)
     {
@@ -581,7 +638,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 logits[29961] = 0;
             }
 
-            id = llama_sample_top_p_top_k(llama_ctx_v1, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty);
+            id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
 
         }
         else
@@ -601,7 +658,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             //gpt2 uses negative logits, so we cant zero it
         }
 
-        id = gptj_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
+        id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
     }
 
     last_n_tokens.erase(last_n_tokens.begin());
koboldcpp.py (10 changes)
@@ -32,6 +32,8 @@ class generation_inputs(ctypes.Structure):
                 ("temperature", ctypes.c_float),
                 ("top_k", ctypes.c_int),
                 ("top_p", ctypes.c_float),
+                ("typical_p", ctypes.c_float),
+                ("tfs", ctypes.c_float),
                 ("rep_pen", ctypes.c_float),
                 ("rep_pen_range", ctypes.c_int),
                 ("stop_sequence", ctypes.c_char_p * stop_token_max)]
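One reason this hunk mirrors the expose.h hunk exactly: ctypes lays out a Structure positionally, just like the C struct it shadows, so typical_p and tfs must be inserted at the same offset on both sides or every later field (rep_pen onward) would be marshalled shifted. A sketch of the contract, annotating the C side with the matching ctypes declarations (the stop_token_max value is an assumption for illustration, and fields before temperature are elided):

// expose.h side of the ABI, annotated with the matching ctypes field.
const int stop_token_max = 10; // assumed illustrative value, not from this diff

struct generation_inputs
{
    // ... fields before temperature elided ...
    const float temperature;   // ("temperature",   ctypes.c_float)
    const int top_k;           // ("top_k",         ctypes.c_int)
    const float top_p;         // ("top_p",         ctypes.c_float)
    const float typical_p;     // ("typical_p",     ctypes.c_float)  <- added
    const float tfs;           // ("tfs",           ctypes.c_float)  <- added
    const float rep_pen;       // ("rep_pen",       ctypes.c_float)
    const int rep_pen_range;   // ("rep_pen_range", ctypes.c_int)
    const char * stop_sequence[stop_token_max]; // ("stop_sequence", ctypes.c_char_p * stop_token_max)
};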
@@ -146,7 +148,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -155,6 +157,8 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     inputs.temperature = temperature
     inputs.top_k = top_k
     inputs.top_p = top_p
+    inputs.typical_p = typical_p
+    inputs.tfs = tfs
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
     inputs.seed = seed
@@ -297,6 +301,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 temperature=genparams.get('temperature', 0.8),
                 top_k=genparams.get('top_k', 200),
                 top_p=genparams.get('top_p', 0.85),
+                typical_p=genparams.get('typical', 1.0),
+                tfs=genparams.get('tfs', 1.0),
                 rep_pen=genparams.get('rep_pen', 1.1),
                 rep_pen_range=genparams.get('rep_pen_range', 128),
                 seed=-1,
@@ -311,6 +317,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 temperature=genparams.get('temperature', 0.8),
                 top_k=genparams.get('top_k', 200),
                 top_p=genparams.get('top_p', 0.85),
+                typical_p=genparams.get('typical', 1.0),
+                tfs=genparams.get('tfs', 1.0),
                 rep_pen=genparams.get('rep_pen', 1.1),
                 rep_pen_range=genparams.get('rep_pen_range', 128),
                 seed=-1,