From 66f2063da25b1da851801132f20f9cb675831740 Mon Sep 17 00:00:00 2001 From: l3utterfly Date: Tue, 3 Oct 2023 03:33:52 +0800 Subject: [PATCH] fixed floating point comparison issues --- llama.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 05b570bd1..5205be576 100644 --- a/llama.cpp +++ b/llama.cpp @@ -930,6 +930,8 @@ static const size_t kB = 1024; static const size_t MB = kB*kB; static const size_t GB = kB*kB*kB; +const double EPSILON = 1e-9; + struct llama_hparams { bool vocab_only; uint32_t n_vocab; @@ -948,7 +950,22 @@ struct llama_hparams { float rope_freq_scale_train; bool operator!=(const llama_hparams & other) const { - return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + if(this->vocab_only != other.vocab_only) return true; + if(this->n_vocab != other.n_vocab) return true; + if(this->n_ctx_train != other.n_ctx_train) return true; + if(this->n_embd != other.n_embd) return true; + if(this->n_head != other.n_head) return true; + if(this->n_head_kv != other.n_head_kv) return true; + if(this->n_layer != other.n_layer) return true; + if(this->n_rot != other.n_rot) return true; + if(this->n_ff != other.n_ff) return true; + + if(std::abs(this->f_norm_eps - other.f_norm_eps) > EPSILON) return true; + if(std::abs(this->f_norm_rms_eps - other.f_norm_rms_eps) > EPSILON) return true; + if(std::abs(this->rope_freq_base_train - other.rope_freq_base_train) > EPSILON) return true; + if(std::abs(this->rope_freq_scale_train - other.rope_freq_scale_train) > EPSILON) return true; + + return false; } uint32_t n_gqa() const {