Fixed a regression where a bad model was producing valid-looking logits after library changes. We now run the eval through the model twice, with different inputs, and compare the logits. If both runs give the same logits for different inputs, the model is broken.

This commit is contained in:
Concedo 2023-04-29 18:25:17 +08:00
parent 5aa185f3f7
commit fe0e4de8e8

View file

@ -60,6 +60,24 @@ inline bool IsNanCheck(float f)
return (u&0x7F800000) == 0x7F800000 && (u&0x7FFFFF); // Both NaN and qNan.
}
// Reports whether two logit arrays start with the same values.
//
// Only the first 5 elements of each array are compared, using exact float
// equality. Returns true when all 5 leading elements match. Returns false as
// soon as one differs. If either array has fewer than 5 elements, or the two
// arrays differ in length, an error is printed and false is returned.
//
// Both arrays are only read, so they are taken by const reference; this also
// lets callers pass const vectors or temporaries.
inline bool LogitsDuplicated(const std::vector<float> & arr1, const std::vector<float> & arr2)
{
    // Unsigned count so comparisons against vector::size() do not mix signedness.
    const size_t compareQty = 5;
    if(arr1.size() < compareQty || arr2.size() < compareQty || arr1.size()!=arr2.size())
    {
        printf("\nError: Logit array sizes are bad!\n");
        return false;
    }
    for(size_t i=0;i<compareQty;++i)
    {
        // Exact comparison: a NaN element never equals itself, so arrays
        // containing NaN in the compared prefix are never reported as duplicated.
        if(arr1[i]!=arr2[i])
        {
            return false;
        }
    }
    return true;
}
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
{
ggml_time_init();
@ -212,7 +230,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
// determine the required inference memory per token:
legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
//if the logits are NAN, it means the model is incompatible
//if the logits are NAN or duplicated, it means the model is incompatible
if(logits.size()>0 && IsNanCheck(logits[0]))
{
printf("\nBad Logits detected! Retrying GPT-J model loading...");
@ -256,8 +274,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
// determine the required inference memory per token:
gptj_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
//if the logits are NAN, it means the model is incompatible
if(logits.size()>0 && IsNanCheck(logits[0]))
//if the logits are NAN or duplicated, it means the model is incompatible
std::vector<float> oldlogits(logits);
//this is another hack because they change the library - we run the eval through the model
//twice and compare logits. if they give the same logits for different inputs, model is broken
gptj_eval(gptj_ctx_v2, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token);
if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits)))
{
printf("\nBad Logits detected! Retrying GPT-J model loading...");
ggml_free(gptj_ctx_v2.ctx);