added eos token id handling for starcoder models, as they use a different EOS ID
This commit is contained in:
parent
3649d35cca
commit
0971f83bca
1 changed files with 10 additions and 0 deletions
|
@ -1173,6 +1173,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
|
int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
|
||||||
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
//special case, starcoder models use ID 0 for EOS
|
||||||
|
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
|
||||||
|
{
|
||||||
|
eosID = 0;
|
||||||
|
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
|
||||||
|
logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// set the logit of the eos token (0) to minimum to avoid sampling it
|
// set the logit of the eos token (0) to minimum to avoid sampling it
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue