From fe0e4de8e82ba5f64e56ba11ca055bb7604722b1 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sat, 29 Apr 2023 18:25:17 +0800 Subject: [PATCH] fixed a regression where a bad model was giving valid logits after library changes. now we run the eval through the model twice and compare logits. if they give the same logits for different inputs, model is broken --- gpttype_adapter.cpp | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 362080b09..b3730ce50 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -60,6 +60,24 @@ inline bool IsNanCheck(float f) return (u&0x7F800000) == 0x7F800000 && (u&0x7FFFFF); // Both NaN and qNan. } +inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr2) +{ + int compareQty = 5; + if(arr1.size() < compareQty || arr2.size() < compareQty || arr1.size()!=arr2.size()) + { + printf("\nError: Logit array sizes are bad!\n"); + return false; + } + for(int i=0;i0 && IsNanCheck(logits[0])) { printf("\nBad Logits detected! Retrying GPT-J model loading..."); @@ -256,8 +274,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in // determine the required inference memory per token: gptj_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); - //if the logits are NAN, it means the model is incompatible - if(logits.size()>0 && IsNanCheck(logits[0])) + //if the logits are NAN or duplicated, it means the model is incompatible + std::vector<float> oldlogits(logits); + + //this is another hack because they change the library - we run the eval through the model + //twice and compare logits. 
if they give the same logits for different inputs, model is broken + gptj_eval(gptj_ctx_v2, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token); + + if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits))) { printf("\nBad Logits detected! Retrying GPT-J model loading..."); ggml_free(gptj_ctx_v2.ctx);