fixed broken typical sampler issues

This commit is contained in:
Concedo 2023-08-29 23:50:59 +08:00
parent cf5d918073
commit 380fa0f0ca

View file

@ -366,6 +366,21 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
} }
} }
// Returns the smallest value in `logits`, capped at zero: if the minimum is
// negative it is returned as-is, otherwise 0 is returned.
// An empty vector yields 0, matching the pointer/size overload of this helper.
// Callers in this file assign the result to tokens they want to suppress.
static float LowestLogit(const std::vector<float> & logits)
{
    if (logits.empty())
    {
        // std::min_element on an empty range returns end(); dereferencing it
        // (or indexing element 0) would read out of bounds.
        return 0.0f;
    }
    const float lowest = *std::min_element(logits.begin(), logits.end());
    return (lowest < 0 ? lowest : 0);
}
// Raw-array variant: scans `size` floats starting at `logits` and returns the
// minimum when it is below zero, or 0 when every value is non-negative.
// A zero-length range returns 0 without touching the pointer.
static float LowestLogit(const float *logits, size_t size)
{
    // Nothing to scan, so there is no minimum to report.
    if (size == 0)
    {
        return 0.0;
    }

    const float * smallest = std::min_element(logits, logits + size);
    if (*smallest < 0)
    {
        return *smallest;
    }
    return 0;
}
static std::string RemoveBell(const std::string & input) //removes the bell character static std::string RemoveBell(const std::string & input) //removes the bell character
{ {
std::string word2; std::string word2;
@ -1442,23 +1457,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
eosID = llama_v3_token_eos(); eosID = llama_v3_token_eos();
} }
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
if (!unbanTokens) if (!unbanTokens)
{ {
// set the logit of the eos token (2) to -INF to avoid sampling it // set the logit of the eos token (2) to -INF to avoid sampling it
logitsPtr[eosID] = -INFINITY; logitsPtr[eosID] = lowestLogit;
} }
if(btsize>0) if(btsize>0)
{ {
for(int t=0;t<btsize;++t) for(int t=0;t<btsize;++t)
{ {
logitsPtr[banned_token_ids[t]]=-INFINITY; logitsPtr[banned_token_ids[t]]=lowestLogit;
} }
} }
} }
else else
{ {
logitsPtr = logits.data(); logitsPtr = logits.data();
float lowestLogit = LowestLogit(logits);
if (!unbanTokens) if (!unbanTokens)
{ {
//gpt2 uses negative logits, so we cant zero it //gpt2 uses negative logits, so we cant zero it
@ -1474,9 +1491,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::GPTJ_5) file_format == FileFormat::GPTJ_5)
{ {
eosID = 50256; eosID = 50256;
if(logits.size() > eosID) if(logits.size() > eosID)
{ {
logits[eosID] = -INFINITY; logits[eosID] = lowestLogit;
} }
else else
{ {
@ -1484,7 +1502,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4) if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
{ {
eosID = 0; eosID = 0;
logits[eosID] = -INFINITY; logits[eosID] = lowestLogit;
} }
} }
} }
@ -1502,7 +1521,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::MPT_1) file_format == FileFormat::MPT_1)
{ {
eosID = 0; eosID = 0;
logits[eosID] = -INFINITY; logits[eosID] = lowestLogit;
} }
} }
@ -1510,7 +1529,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{ {
for (int t = 0; t < btsize; ++t) for (int t = 0; t < btsize; ++t)
{ {
logits[banned_token_ids[t]] = -INFINITY; logits[banned_token_ids[t]] = lowestLogit;
} }
} }
} }