diff --git a/model_adapter.cpp b/model_adapter.cpp index da9fa193e..514764ec0 100644 --- a/model_adapter.cpp +++ b/model_adapter.cpp @@ -133,28 +133,36 @@ void print_tok_vec(std::vector &embd) else if(vocabsiz==50257 || (vocabsiz>=49152&&vocabsiz<=49157)) //49152-6 is starcoder { fileformat = FileFormat::GPT2_1; - uint32_t temp; - fin.read((char *)&temp, sizeof(temp)); //ctx - fin.read((char *)&temp, sizeof(temp)); //n_embd - fin.read((char *)&temp, sizeof(temp)); //n_head + uint32_t temp, v1,v2,v3; + fin.read((char *)&v1, sizeof(temp)); //ctx + fin.read((char *)&v2, sizeof(temp)); //n_embd + fin.read((char *)&v3, sizeof(temp)); //n_head fin.read((char *)&temp, sizeof(temp)); //n_layer - fin.read((char *)&temp, sizeof(temp)); //f16 - const int32_t qntvr = temp / 1000; - temp %= 1000; - if (qntvr != 0) + if(vocabsiz==49152 && v1==4096 && v2==2560 && v3==32 && temp==32) { - if (qntvr == 1) - { - fileformat = FileFormat::GPT2_3; - } - else - { - fileformat = FileFormat::GPT2_4; - } + //special case, Stablecode Completion Alpha 3B + fileformat = FileFormat::NEOX_6; } - else if (temp != 0 && temp != 1) + else { - fileformat = FileFormat::GPT2_2; //quantized format cannot be legacy type + fin.read((char *)&temp, sizeof(temp)); //f16 + const int32_t qntvr = temp / 1000; + temp %= 1000; + if (qntvr != 0) + { + if (qntvr == 1) + { + fileformat = FileFormat::GPT2_3; + } + else + { + fileformat = FileFormat::GPT2_4; + } + } + else if (temp != 0 && temp != 1) + { + fileformat = FileFormat::GPT2_2; //quantized format cannot be legacy type + } } } else if(vocabsiz < 31998 || vocabsiz > 33000)