updated implementation for hparam comparison to handle inf and NaN

l3utterfly 2023-10-04 18:01:50 +08:00
parent 66f2063da2
commit 16f45c4dec


@@ -123,6 +123,28 @@ static void replace_all(std::string & s, const std::string & search, const std::
     }
     s = std::move(result);
 }
+
+bool is_float_eq(float a, float b, float abs_tol) {
+    // Check for non-negative tolerance
+    if (abs_tol < 0.0) {
+        throw std::invalid_argument("Tolerance must be non-negative");
+    }
+
+    // Exact equality check
+    if (a == b) {
+        return true;
+    }
+
+    // Check for infinities
+    if (std::isinf(a) || std::isinf(b)) {
+        return false;
+    }
+
+    // Regular comparison using the provided absolute tolerance
+    double diff = std::fabs(b - a);
+    return (diff <= abs_tol);
+}
+
 #ifdef GGML_USE_CPU_HBM
 #include <hbwmalloc.h>
 #endif
@@ -930,8 +952,6 @@ static const size_t kB = 1024;
 static const size_t MB = kB*kB;
 static const size_t GB = kB*kB*kB;
 
-const double EPSILON = 1e-9;
-
 struct llama_hparams {
     bool vocab_only;
     uint32_t n_vocab;
@@ -949,23 +969,30 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
 
-    bool operator!=(const llama_hparams & other) const {
-        if(this->vocab_only != other.vocab_only) return true;
-        if(this->n_vocab != other.n_vocab) return true;
-        if(this->n_ctx_train != other.n_ctx_train) return true;
-        if(this->n_embd != other.n_embd) return true;
-        if(this->n_head != other.n_head) return true;
-        if(this->n_head_kv != other.n_head_kv) return true;
-        if(this->n_layer != other.n_layer) return true;
-        if(this->n_rot != other.n_rot) return true;
-        if(this->n_ff != other.n_ff) return true;
-        if(std::abs(this->f_norm_eps - other.f_norm_eps) > EPSILON) return true;
-        if(std::abs(this->f_norm_rms_eps - other.f_norm_rms_eps) > EPSILON) return true;
-        if(std::abs(this->rope_freq_base_train - other.rope_freq_base_train) > EPSILON) return true;
-        if(std::abs(this->rope_freq_scale_train - other.rope_freq_scale_train) > EPSILON) return true;
-        return false;
+    bool operator==(const llama_hparams & other) const {
+        if (this->vocab_only != other.vocab_only) return false;
+        if (this->n_vocab != other.n_vocab) return false;
+        if (this->n_ctx_train != other.n_ctx_train) return false;
+        if (this->n_embd != other.n_embd) return false;
+        if (this->n_head != other.n_head) return false;
+        if (this->n_head_kv != other.n_head_kv) return false;
+        if (this->n_layer != other.n_layer) return false;
+        if (this->n_rot != other.n_rot) return false;
+        if (this->n_ff != other.n_ff) return false;
+
+        const float EPSILON = 1e-9;
+        if (!is_float_eq(this->f_norm_eps, other.f_norm_eps, EPSILON)) return false;
+        if (!is_float_eq(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return false;
+        if (!is_float_eq(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return false;
+        if (!is_float_eq(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return false;
+
+        return true;
+    }
+
+    // implement != explicitly using the "==" implementation above so we don't get a warning about it
+    bool operator!=(const llama_hparams & other) const {
+        return !(*this == other);
     }
 
     uint32_t n_gqa() const {
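
Not part of the commit: a minimal standalone sketch of how the new is_float_eq comparison behaves for the inf/NaN cases named in the commit message. The main() harness and the expected outputs in the comments are assumptions added here for illustration only.

// Standalone sketch (assumed test harness, not from the commit) exercising
// is_float_eq with the special values the commit message mentions.
#include <cmath>
#include <cstdio>
#include <limits>
#include <stdexcept>

static bool is_float_eq(float a, float b, float abs_tol) {
    if (abs_tol < 0.0) {
        throw std::invalid_argument("Tolerance must be non-negative");
    }
    if (a == b) {
        return true;                      // exact match, including +inf==+inf and -inf==-inf
    }
    if (std::isinf(a) || std::isinf(b)) {
        return false;                     // a lone infinity (or opposite-signed infinities) never matches
    }
    double diff = std::fabs(b - a);
    return (diff <= abs_tol);             // finite values: absolute-tolerance comparison
}

int main() {
    const float inf = std::numeric_limits<float>::infinity();
    const float nan = std::numeric_limits<float>::quiet_NaN();

    std::printf("%d\n", is_float_eq(1.0f, 1.0f + 1e-6f, 1e-5f)); // 1: within tolerance
    std::printf("%d\n", is_float_eq(inf,  inf,  1e-9f));         // 1: handled by the exact-equality branch
    std::printf("%d\n", is_float_eq(inf, -inf,  1e-9f));         // 0: infinities of opposite sign
    std::printf("%d\n", is_float_eq(nan,  nan,  1e-9f));         // 0: NaN fails every comparison, so diff <= tol is false
    return 0;
}

With this comparison, matching infinities count as equal (via the exact-equality branch), while any NaN in a compared field makes operator== report the hparams as different.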