From 0971f83bca2266fba477932c2285c8c8600b5bfb Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Thu, 15 Jun 2023 22:57:14 +0800 Subject: [PATCH] added eos token id handling for starcoder models, as they use a different EOS ID --- gpttype_adapter.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 775f29bb6..bbd405757 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1173,6 +1173,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o int topid = std::min_element(logits.begin(),logits.end())-logits.begin(); logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0); } + else + { + //special case, starcoder models use ID 0 for EOS + if (file_format == FileFormat::GPT2_3 || file_format == FileFormat::GPT2_4) + { + eosID = 0; + int topid = std::min_element(logits.begin(), logits.end()) - logits.begin(); + logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0); + } + } } // set the logit of the eos token (0) to minimum to avoid sampling it