diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 775f29bb6..bbd405757 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1173,6 +1173,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o int topid = std::min_element(logits.begin(),logits.end())-logits.begin(); logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0); } + else + { + //special case, starcoder models use ID 0 for EOS + if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4) + { + eosID = 0; + int topid = std::min_element(logits.begin(), logits.end()) - logits.begin(); + logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0); + } + } } // set the logit of the eos token (0) to minimum to avoid sampling it