Add repetition penalty (#20)

* Adding repeat penalization

* Update utils.h

* Update utils.cpp

* Numeric fix

Should probably still scale by temp even if penalized

* Update comments, more proper application

I see that numbers can go negative so a fix from a referenced commit

* Minor formatting

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
beiller 2023-03-12 05:27:42 -04:00 committed by GitHub
parent 702fddf5c5
commit 129c7d1ea8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 3 deletions

View file

@ -792,7 +792,7 @@ int main(int argc, char ** argv) {
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
}
printf("\n");
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p);
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
printf("\n\n");
std::vector<gpt_vocab::id> embd;
@ -801,6 +801,10 @@ int main(int argc, char ** argv) {
size_t mem_per_token = 0;
llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
int last_n_size = params.repeat_last_n;
std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.size() > 0) {
@ -821,6 +825,7 @@ int main(int argc, char ** argv) {
// sample next token
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
const int n_vocab = model.hparams.n_vocab;
@ -829,7 +834,10 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
t_sample_us += ggml_time_us() - t_start_sample_us;
}
@ -840,6 +848,8 @@ int main(int argc, char ** argv) {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[k]);
if (embd.size() > params.n_batch) {
break;
}