Add --ignore-eos parameter (#181)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
084e2f0ec0
commit
50fae10d03
3 changed files with 14 additions and 1 deletions
10
main.cpp
10
main.cpp
|
@ -27,6 +27,8 @@
|
|||
#define ANSI_COLOR_RESET "\x1b[0m"
|
||||
#define ANSI_BOLD "\x1b[1m"
|
||||
|
||||
static const int EOS_TOKEN_ID = 2;
|
||||
|
||||
// determine number of model parts based on the dimension
|
||||
static const std::map<int, int> LLAMA_N_PARTS = {
|
||||
{ 4096, 1 },
|
||||
|
@ -956,6 +958,11 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
if (params.ignore_eos) {
|
||||
// set the logit of the eos token to zero to avoid sampling it
|
||||
logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
|
||||
}
|
||||
|
||||
id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
|
@ -1055,7 +1062,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 2) {
|
||||
|
||||
if (embd.back() == EOS_TOKEN_ID) {
|
||||
if (params.interactive) {
|
||||
is_interacting = true;
|
||||
} else {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue