diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 362080b09..b3730ce50 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -60,6 +60,24 @@ inline bool IsNanCheck(float f)
     return (u&0x7F800000) == 0x7F800000 && (u&0x7FFFFF); // Both NaN and qNan.
 }
 
+inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr2)
+{
+    int compareQty = 5;
+    if(arr1.size() < compareQty || arr2.size() < compareQty || arr1.size()!=arr2.size())
+    {
+        printf("\nError: Logit array sizes are bad!\n");
+        return false;
+    }
+    for(int i=0;i<compareQty;++i)
+    {
+        if(arr1[i]!=arr2[i])
+        {
+            return false;
+        }
+    }
+    return true;
+}
+
@@ ... @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         if(logits.size()>0 && IsNanCheck(logits[0]))
         {
             printf("\nBad Logits detected! Retrying GPT-J model loading...");
@@ -256,8 +274,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         // determine the required inference memory per token:
         gptj_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
-        //if the logits are NAN, it means the model is incompatible
-        if(logits.size()>0 && IsNanCheck(logits[0]))
+        //if the logits are NAN or duplicated, it means the model is incompatible
+        std::vector<float> oldlogits(logits);
+
+        //this is another hack because they change the library - we run the eval through the model
+        //twice and compare logits. if they give the same logits for different inputs, model is broken
+        gptj_eval(gptj_ctx_v2, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token);
+
+        if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits)))
         {
             printf("\nBad Logits detected! Retrying GPT-J model loading...");
             ggml_free(gptj_ctx_v2.ctx);
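
For reference, below is a minimal standalone sketch (not part of the patch) of the heuristic the hunks above introduce: evaluate the model on two different token sequences and flag it as broken if the leading logits come back identical. The fake_eval helper and the logits_duplicated name are illustrative stand-ins for gptj_eval and LogitsDuplicated; only the comparison logic mirrors the patch.

// Standalone sketch of the duplicate-logits check; fake_eval is a hypothetical
// stand-in for a real model eval call such as gptj_eval.
#include <cstdio>
#include <vector>

// Returns true when the first compareQty logits of both runs are identical,
// which the patch treats as a sign the model evaluation is broken.
static bool logits_duplicated(const std::vector<float> &a, const std::vector<float> &b)
{
    const size_t compareQty = 5;
    if (a.size() < compareQty || b.size() < compareQty || a.size() != b.size())
    {
        std::printf("\nError: Logit array sizes are bad!\n");
        return false;
    }
    for (size_t i = 0; i < compareQty; ++i)
    {
        if (a[i] != b[i])
        {
            return false; // outputs differ, so the model responds to its input
        }
    }
    return true; // identical outputs for different inputs: likely broken
}

// Hypothetical eval stand-in: a healthy model should produce different
// logits for different input token ids.
static std::vector<float> fake_eval(const std::vector<int> &tokens)
{
    std::vector<float> logits(8);
    for (size_t i = 0; i < logits.size(); ++i)
    {
        logits[i] = 0.1f * tokens[i % tokens.size()] + 0.01f * i;
    }
    return logits;
}

int main()
{
    // Two different prompts, mirroring the { 0, 1, 2, 3 } and {4, 5, 6, 7} calls in the patch.
    std::vector<float> run1 = fake_eval({0, 1, 2, 3});
    std::vector<float> run2 = fake_eval({4, 5, 6, 7});

    if (logits_duplicated(run1, run2))
    {
        std::printf("Bad logits detected, model looks broken.\n");
    }
    else
    {
        std::printf("Logits differ between runs, model looks OK.\n");
    }
    return 0;
}

In the patch itself the second eval reuses the same logits buffer, so the original output is copied into oldlogits first; the sketch simply keeps the two runs in separate vectors.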