Update utils.cpp
This commit is contained in:
parent
3f6a118d6a
commit
78651d5792
1 changed files with 13 additions and 1 deletions
12
utils.cpp
12
utils.cpp
|
@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.top_p = std::stof(argv[++i]);
|
||||
} else if (arg == "--temp") {
|
||||
params.temp = std::stof(argv[++i]);
|
||||
} else if (arg == "--repeat_last_n") {
|
||||
params.repeat_last_n = std::stoi(argv[++i]);
|
||||
} else if (arg == "--repeat_penalty") {
|
||||
params.repeat_penalty = std::stof(argv[++i]);
|
||||
} else if (arg == "-b" || arg == "--batch_size") {
|
||||
params.n_batch = std::stoi(argv[++i]);
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
|
@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
|
|||
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
|
||||
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
||||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
||||
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
|
||||
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
|
||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
|
@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
|||
gpt_vocab::id llama_sample_top_p(
|
||||
const gpt_vocab & vocab,
|
||||
const float * logits,
|
||||
std::vector<gpt_vocab::id> & last_n_tokens,
|
||||
double repeat_penalty,
|
||||
double top_p,
|
||||
double temp,
|
||||
std::mt19937 & rng) {
|
||||
|
@ -383,9 +391,13 @@ gpt_vocab::id llama_sample_top_p(
|
|||
{
|
||||
const double scale = 1.0/temp;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) {
|
||||
logits_id.push_back(std::make_pair(logits[i]*(1/repeat_penalty), i));
|
||||
} else {
|
||||
logits_id.push_back(std::make_pair(logits[i]*scale, i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(
|
||||
logits_id.begin(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue