fixed broken typical sampler issues
This commit is contained in:
parent
cf5d918073
commit
380fa0f0ca
1 changed files with 25 additions and 6 deletions
|
@ -366,6 +366,21 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static float LowestLogit(const std::vector<float> & logits)
|
||||||
|
{
|
||||||
|
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
|
||||||
|
return (logits[topid] < 0 ? logits[topid] : 0);
|
||||||
|
}
|
||||||
|
static float LowestLogit(const float *logits, size_t size)
|
||||||
|
{
|
||||||
|
if (size == 0) {
|
||||||
|
// Handle the case of an empty array
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
int topid = std::min_element(logits, logits + size) - logits;
|
||||||
|
return (logits[topid] < 0 ? logits[topid] : 0);
|
||||||
|
}
|
||||||
|
|
||||||
static std::string RemoveBell(const std::string & input) //removes the bell character
|
static std::string RemoveBell(const std::string & input) //removes the bell character
|
||||||
{
|
{
|
||||||
std::string word2;
|
std::string word2;
|
||||||
|
@ -1442,23 +1457,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
eosID = llama_v3_token_eos();
|
eosID = llama_v3_token_eos();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
|
||||||
if (!unbanTokens)
|
if (!unbanTokens)
|
||||||
{
|
{
|
||||||
// set the logit of the eos token (2) to -INF to avoid sampling it
|
// set the logit of the eos token (2) to -INF to avoid sampling it
|
||||||
logitsPtr[eosID] = -INFINITY;
|
logitsPtr[eosID] = lowestLogit;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(btsize>0)
|
if(btsize>0)
|
||||||
{
|
{
|
||||||
for(int t=0;t<btsize;++t)
|
for(int t=0;t<btsize;++t)
|
||||||
{
|
{
|
||||||
logitsPtr[banned_token_ids[t]]=-INFINITY;
|
logitsPtr[banned_token_ids[t]]=lowestLogit;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
logitsPtr = logits.data();
|
logitsPtr = logits.data();
|
||||||
|
float lowestLogit = LowestLogit(logits);
|
||||||
if (!unbanTokens)
|
if (!unbanTokens)
|
||||||
{
|
{
|
||||||
//gpt2 uses negative logits, so we cant zero it
|
//gpt2 uses negative logits, so we cant zero it
|
||||||
|
@ -1474,9 +1491,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
file_format == FileFormat::GPTJ_5)
|
file_format == FileFormat::GPTJ_5)
|
||||||
{
|
{
|
||||||
eosID = 50256;
|
eosID = 50256;
|
||||||
|
|
||||||
if(logits.size() > eosID)
|
if(logits.size() > eosID)
|
||||||
{
|
{
|
||||||
logits[eosID] = -INFINITY;
|
logits[eosID] = lowestLogit;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -1484,7 +1502,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
|
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
|
||||||
{
|
{
|
||||||
eosID = 0;
|
eosID = 0;
|
||||||
logits[eosID] = -INFINITY;
|
logits[eosID] = lowestLogit;
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1502,7 +1521,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
file_format == FileFormat::MPT_1)
|
file_format == FileFormat::MPT_1)
|
||||||
{
|
{
|
||||||
eosID = 0;
|
eosID = 0;
|
||||||
logits[eosID] = -INFINITY;
|
logits[eosID] = lowestLogit;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1510,7 +1529,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
{
|
{
|
||||||
for (int t = 0; t < btsize; ++t)
|
for (int t = 0; t < btsize; ++t)
|
||||||
{
|
{
|
||||||
logits[banned_token_ids[t]] = -INFINITY;
|
logits[banned_token_ids[t]] = lowestLogit;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue