fixed floating point comparison issues

This commit is contained in:
l3utterfly 2023-10-03 03:33:52 +08:00
parent 9476b01226
commit 66f2063da2

View file

@ -930,6 +930,8 @@ static const size_t kB = 1024;
static const size_t MB = kB*kB;
static const size_t GB = kB*kB*kB;
const double EPSILON = 1e-9;
struct llama_hparams {
bool vocab_only;
uint32_t n_vocab;
@ -948,7 +950,22 @@ struct llama_hparams {
float rope_freq_scale_train;
bool operator!=(const llama_hparams & other) const {
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
if(this->vocab_only != other.vocab_only) return true;
if(this->n_vocab != other.n_vocab) return true;
if(this->n_ctx_train != other.n_ctx_train) return true;
if(this->n_embd != other.n_embd) return true;
if(this->n_head != other.n_head) return true;
if(this->n_head_kv != other.n_head_kv) return true;
if(this->n_layer != other.n_layer) return true;
if(this->n_rot != other.n_rot) return true;
if(this->n_ff != other.n_ff) return true;
if(std::abs(this->f_norm_eps - other.f_norm_eps) > EPSILON) return true;
if(std::abs(this->f_norm_rms_eps - other.f_norm_rms_eps) > EPSILON) return true;
if(std::abs(this->rope_freq_base_train - other.rope_freq_base_train) > EPSILON) return true;
if(std::abs(this->rope_freq_scale_train - other.rope_freq_scale_train) > EPSILON) return true;
return false;
}
uint32_t n_gqa() const {