fixed broken typical sampler issues
This commit is contained in:
parent
cf5d918073
commit
380fa0f0ca
1 changed files with 25 additions and 6 deletions
|
@ -366,6 +366,21 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
|
|||
}
|
||||
}
|
||||
|
||||
static float LowestLogit(const std::vector<float> & logits)
|
||||
{
|
||||
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
|
||||
return (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
static float LowestLogit(const float *logits, size_t size)
|
||||
{
|
||||
if (size == 0) {
|
||||
// Handle the case of an empty array
|
||||
return 0.0;
|
||||
}
|
||||
int topid = std::min_element(logits, logits + size) - logits;
|
||||
return (logits[topid] < 0 ? logits[topid] : 0);
|
||||
}
|
||||
|
||||
static std::string RemoveBell(const std::string & input) //removes the bell character
|
||||
{
|
||||
std::string word2;
|
||||
|
@ -1442,23 +1457,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
eosID = llama_v3_token_eos();
|
||||
}
|
||||
|
||||
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
|
||||
if (!unbanTokens)
|
||||
{
|
||||
// set the logit of the eos token (2) to -INF to avoid sampling it
|
||||
logitsPtr[eosID] = -INFINITY;
|
||||
logitsPtr[eosID] = lowestLogit;
|
||||
}
|
||||
|
||||
if(btsize>0)
|
||||
{
|
||||
for(int t=0;t<btsize;++t)
|
||||
{
|
||||
logitsPtr[banned_token_ids[t]]=-INFINITY;
|
||||
logitsPtr[banned_token_ids[t]]=lowestLogit;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
logitsPtr = logits.data();
|
||||
float lowestLogit = LowestLogit(logits);
|
||||
if (!unbanTokens)
|
||||
{
|
||||
//gpt2 uses negative logits, so we cant zero it
|
||||
|
@ -1474,9 +1491,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
file_format == FileFormat::GPTJ_5)
|
||||
{
|
||||
eosID = 50256;
|
||||
|
||||
if(logits.size() > eosID)
|
||||
{
|
||||
logits[eosID] = -INFINITY;
|
||||
logits[eosID] = lowestLogit;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1484,7 +1502,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
|
||||
{
|
||||
eosID = 0;
|
||||
logits[eosID] = -INFINITY;
|
||||
logits[eosID] = lowestLogit;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1502,7 +1521,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
file_format == FileFormat::MPT_1)
|
||||
{
|
||||
eosID = 0;
|
||||
logits[eosID] = -INFINITY;
|
||||
logits[eosID] = lowestLogit;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1510,7 +1529,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
{
|
||||
for (int t = 0; t < btsize; ++t)
|
||||
{
|
||||
logits[banned_token_ids[t]] = -INFINITY;
|
||||
logits[banned_token_ids[t]] = lowestLogit;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue