Added EOS token ID handling for StarCoder models, as they use a different EOS ID
This commit is contained in:
parent
3649d35cca
commit
0971f83bca
1 changed file with 10 additions and 0 deletions
|
@ -1173,6 +1173,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
|
||||
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
//special case, starcoder models use ID 0 for EOS
|
||||
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
|
||||
{
|
||||
eosID = 0;
|
||||
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
|
||||
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set the logit of the eos token (0) to minimum to avoid sampling it
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue