/// Returns the floor value used when suppressing tokens: the smallest logit in
/// [logits, logits + size), capped at 0 so the result is never positive.
///
/// gpttype_generate computes this once per step and assigns it (instead of
/// -INFINITY) to the EOS token and to every banned token id.
///
/// @param logits Pointer to the first logit. A null pointer is treated as an
///               empty range.
/// @param size   Number of logits readable through `logits`.
/// @return 0 for an empty range; otherwise the minimum logit if it is
///         negative, else 0.
static float LowestLogit(const float *logits, size_t size)
{
    // std::min_element on an empty range returns `last`, which must not be
    // dereferenced, so bail out before touching the data.
    if (logits == nullptr || size == 0) {
        return 0.0f;
    }
    const float lowest = *std::min_element(logits, logits + size);
    // Written as a comparison (not std::min) so a non-comparable minimum
    // falls through to 0, exactly as before.
    return (lowest < 0.0f ? lowest : 0.0f);
}

/// Vector convenience overload; see LowestLogit(const float *, size_t).
/// Delegating keeps the empty-input handling identical for both overloads.
///
/// @param logits Logits to scan; may be empty.
/// @return 0 for an empty vector; otherwise the minimum logit if it is
///         negative, else 0.
static float LowestLogit(const std::vector<float> & logits)
{
    return LowestLogit(logits.data(), logits.size());
}