Added EOS token ID handling for StarCoder models, as they use a different EOS ID

This commit is contained in:
Concedo 2023-06-15 22:57:14 +08:00
parent 3649d35cca
commit 0971f83bca

View file

@ -1173,6 +1173,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
int topid = std::min_element(logits.begin(),logits.end())-logits.begin(); int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0); logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
} }
else
{
//special case, starcoder models use ID 0 for EOS
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
{
eosID = 0;
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
}
}
} }
// set the logit of the eos token (0) to minimum to avoid sampling it // set the logit of the eos token (0) to minimum to avoid sampling it