fixed broken typical sampler issues

This commit is contained in:
Concedo 2023-08-29 23:50:59 +08:00
parent cf5d918073
commit 380fa0f0ca

View file

@ -366,6 +366,21 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
}
}
static float LowestLogit(const std::vector<float> & logits)
{
int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
return (logits[topid] < 0 ? logits[topid] : 0);
}
static float LowestLogit(const float *logits, size_t size)
{
if (size == 0) {
// Handle the case of an empty array
return 0.0;
}
int topid = std::min_element(logits, logits + size) - logits;
return (logits[topid] < 0 ? logits[topid] : 0);
}
static std::string RemoveBell(const std::string & input) //removes the bell character
{
std::string word2;
@ -1442,23 +1457,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
eosID = llama_v3_token_eos();
}
float lowestLogit = LowestLogit(logitsPtr,n_vocab);
if (!unbanTokens)
{
// set the logit of the eos token (2) to -INF to avoid sampling it
logitsPtr[eosID] = -INFINITY;
logitsPtr[eosID] = lowestLogit;
}
if(btsize>0)
{
for(int t=0;t<btsize;++t)
{
logitsPtr[banned_token_ids[t]]=-INFINITY;
logitsPtr[banned_token_ids[t]]=lowestLogit;
}
}
}
else
{
logitsPtr = logits.data();
float lowestLogit = LowestLogit(logits);
if (!unbanTokens)
{
//gpt2 uses negative logits, so we cant zero it
@ -1474,9 +1491,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::GPTJ_5)
{
eosID = 50256;
if(logits.size() > eosID)
{
logits[eosID] = -INFINITY;
logits[eosID] = lowestLogit;
}
else
{
@ -1484,7 +1502,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4)
{
eosID = 0;
logits[eosID] = -INFINITY;
logits[eosID] = lowestLogit;
}
}
}
@ -1502,7 +1521,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
file_format == FileFormat::MPT_1)
{
eosID = 0;
logits[eosID] = -INFINITY;
logits[eosID] = lowestLogit;
}
}
@ -1510,7 +1529,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
for (int t = 0; t < btsize; ++t)
{
logits[banned_token_ids[t]] = -INFINITY;
logits[banned_token_ids[t]] = lowestLogit;
}
}
}