Fixed a regression where a bad model was giving valid logits after library changes. Now we run the eval through the model twice and compare the logits: if the model gives the same logits for different inputs, it is broken.
This commit is contained in:
parent
5aa185f3f7
commit
fe0e4de8e8
1 changed files with 27 additions and 3 deletions
|
@ -60,6 +60,24 @@ inline bool IsNanCheck(float f)
|
|||
return (u&0x7F800000) == 0x7F800000 && (u&0x7FFFFF); // Both NaN and qNan.
|
||||
}
|
||||
|
||||
// Reports whether two logit vectors start with the same values.
//
// Only the first `compareQty` entries are inspected, and they are compared
// with exact float equality. The caller runs the model on two different
// inputs and treats identical leading logits as a sign of a broken model.
//
// arr1, arr2: the logit vectors to compare; neither is modified.
//
// Returns true when both vectors have the same length, contain at least
// `compareQty` entries, and agree on every one of the first `compareQty`
// entries. Returns false otherwise. A length mismatch or a too-short vector
// additionally prints an error message.
inline bool LogitsDuplicated(const std::vector<float> & arr1, const std::vector<float> & arr2)
{
    // Unsigned on purpose: it is compared against size() and used as an index bound.
    const size_t compareQty = 5;
    if(arr1.size() < compareQty || arr2.size() < compareQty || arr1.size()!=arr2.size())
    {
        printf("\nError: Logit array sizes are bad!\n");
        return false;
    }
    for(size_t i=0;i<compareQty;++i)
    {
        // Exact (not epsilon) comparison: we are looking for identical output,
        // not numerically close output. NaN entries never compare equal.
        if(arr1[i]!=arr2[i])
        {
            return false;
        }
    }
    return true;
}
|
||||
|
||||
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
|
||||
{
|
||||
ggml_time_init();
|
||||
|
@ -212,7 +230,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
// determine the required inference memory per token:
|
||||
legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
|
||||
|
||||
//if the logits are NAN, it means the model is incompatible
|
||||
//if the logits are NAN or duplicated, it means the model is incompatible
|
||||
if(logits.size()>0 && IsNanCheck(logits[0]))
|
||||
{
|
||||
printf("\nBad Logits detected! Retrying GPT-J model loading...");
|
||||
|
@ -256,8 +274,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
// determine the required inference memory per token:
|
||||
gptj_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||
|
||||
//if the logits are NAN, it means the model is incompatible
|
||||
if(logits.size()>0 && IsNanCheck(logits[0]))
|
||||
//if the logits are NAN or duplicated, it means the model is incompatible
|
||||
std::vector<float> oldlogits(logits);
|
||||
|
||||
//this is another hack because they change the library - we run the eval through the model
|
||||
//twice and compare logits. if they give the same logits for different inputs, model is broken
|
||||
gptj_eval(gptj_ctx_v2, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token);
|
||||
|
||||
if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits)))
|
||||
{
|
||||
printf("\nBad Logits detected! Retrying GPT-J model loading...");
|
||||
ggml_free(gptj_ctx_v2.ctx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue