Add back top_k (#56)

* Add back top_k

* Update utils.cpp

* Update utils.h

---------

Co-authored-by: Bill Hamilton <bill.hamilton@shopify.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
beiller 2023-03-12 16:23:15 -04:00 committed by GitHub
parent eb062bb012
commit 02f0c6fe7f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 89 deletions

19
utils.h
View file

@ -19,7 +19,7 @@ struct gpt_params {
int32_t repeat_last_n = 64; // last n tokens to penalize
// sampling parameters
int32_t top_k = 40; // unused
int32_t top_k = 40;
float top_p = 0.95f;
float temp = 0.80f;
float repeat_penalty = 1.30f;
@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
// TODO: not sure if this implementation is correct
// TODO: temperature is not implemented
//
// Parameters (only the declaration is visible here; semantics are inferred
// from the names and the notes above - confirm against the definition):
//   vocab  - vocabulary, taken by const reference
//   logits - read-only pointer to float values; the expected length is not
//            visible from this declaration
//   top_k  - the "K" in the description above
//   top_p  - the "P" in the description above
//   temp   - presumably the sampling temperature; per the TODO above it is
//            not implemented
//   rng    - std::mt19937 taken by non-const reference, so the call may
//            advance its state
// Returns a gpt_vocab::id.
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const float * logits,
int top_k,
double top_p,
double temp,
std::mt19937 & rng);
// NOTE(review): two opening lines appear below for a single parameter list
// (`llama_sample_top_p(` followed by `llama_sample_top_p_top_k(`). This looks
// like the old and new names from a diff shown without +/- markers - confirm
// which one the header actually declares.
//
// Parameters (semantics inferred from names - confirm against the definition):
//   vocab          - vocabulary, taken by const reference
//   logits         - read-only pointer to float values; length not visible here
//   last_n_tokens  - vector of gpt_vocab::id taken by non-const reference;
//                    presumably the recent tokens subject to the repeat penalty
//   repeat_penalty - presumably the penalty factor applied to repeated tokens
//   top_k          - presumably the number of top tokens to keep
//   top_p          - presumably the cumulative-probability cutoff
//   temp           - presumably the sampling temperature
//   rng            - std::mt19937 taken by non-const reference, so the call
//                    may advance its state
// Returns a gpt_vocab::id.
gpt_vocab::id llama_sample_top_p(
gpt_vocab::id llama_sample_top_p_top_k(
const gpt_vocab & vocab,
const float * logits,
std::vector<gpt_vocab::id> & last_n_tokens,
double repeat_penalty,
int top_k,
double top_p,
double temp,
std::mt19937 & rng);
// Filter a list of logits down to the top K tokens.
// logits_id holds (double, gpt_vocab::id) pairs and is taken by non-const
// reference, so it is presumably modified in place - confirm against the
// definition.
void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
//
// Quantization
//