updated implementation for hparam comparison to handle inf and NaN
This commit is contained in:
parent
66f2063da2
commit
16f45c4dec
1 changed files with 45 additions and 18 deletions
63
llama.cpp
63
llama.cpp
|
@ -123,6 +123,28 @@ static void replace_all(std::string & s, const std::string & search, const std::
|
||||||
}
|
}
|
||||||
s = std::move(result);
|
s = std::move(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_float_eq(float a, float b, float abs_tol) {
|
||||||
|
// Check for non-negative tolerance
|
||||||
|
if (abs_tol < 0.0) {
|
||||||
|
throw std::invalid_argument("Tolerance must be non-negative");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exact equality check
|
||||||
|
if (a == b) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for infinities
|
||||||
|
if (std::isinf(a) || std::isinf(b)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular comparison using the provided absolute tolerance
|
||||||
|
double diff = std::fabs(b - a);
|
||||||
|
return (diff <= abs_tol);
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef GGML_USE_CPU_HBM
|
#ifdef GGML_USE_CPU_HBM
|
||||||
#include <hbwmalloc.h>
|
#include <hbwmalloc.h>
|
||||||
#endif
|
#endif
|
||||||
|
@ -930,8 +952,6 @@ static const size_t kB = 1024;
|
||||||
static const size_t MB = kB*kB;
|
static const size_t MB = kB*kB;
|
||||||
static const size_t GB = kB*kB*kB;
|
static const size_t GB = kB*kB*kB;
|
||||||
|
|
||||||
const double EPSILON = 1e-9;
|
|
||||||
|
|
||||||
struct llama_hparams {
|
struct llama_hparams {
|
||||||
bool vocab_only;
|
bool vocab_only;
|
||||||
uint32_t n_vocab;
|
uint32_t n_vocab;
|
||||||
|
@ -949,23 +969,30 @@ struct llama_hparams {
|
||||||
float rope_freq_base_train;
|
float rope_freq_base_train;
|
||||||
float rope_freq_scale_train;
|
float rope_freq_scale_train;
|
||||||
|
|
||||||
|
bool operator==(const llama_hparams & other) const {
|
||||||
|
if (this->vocab_only != other.vocab_only) return false;
|
||||||
|
if (this->n_vocab != other.n_vocab) return false;
|
||||||
|
if (this->n_ctx_train != other.n_ctx_train) return false;
|
||||||
|
if (this->n_embd != other.n_embd) return false;
|
||||||
|
if (this->n_head != other.n_head) return false;
|
||||||
|
if (this->n_head_kv != other.n_head_kv) return false;
|
||||||
|
if (this->n_layer != other.n_layer) return false;
|
||||||
|
if (this->n_rot != other.n_rot) return false;
|
||||||
|
if (this->n_ff != other.n_ff) return false;
|
||||||
|
|
||||||
|
const float EPSILON = 1e-9;
|
||||||
|
|
||||||
|
if (!is_float_eq(this->f_norm_eps, other.f_norm_eps, EPSILON)) return false;
|
||||||
|
if (!is_float_eq(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return false;
|
||||||
|
if (!is_float_eq(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return false;
|
||||||
|
if (!is_float_eq(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// implement != explicitly using the "==" implementation above so we don't get a warning about it
|
||||||
bool operator!=(const llama_hparams & other) const {
|
bool operator!=(const llama_hparams & other) const {
|
||||||
if(this->vocab_only != other.vocab_only) return true;
|
return !(*this == other);
|
||||||
if(this->n_vocab != other.n_vocab) return true;
|
|
||||||
if(this->n_ctx_train != other.n_ctx_train) return true;
|
|
||||||
if(this->n_embd != other.n_embd) return true;
|
|
||||||
if(this->n_head != other.n_head) return true;
|
|
||||||
if(this->n_head_kv != other.n_head_kv) return true;
|
|
||||||
if(this->n_layer != other.n_layer) return true;
|
|
||||||
if(this->n_rot != other.n_rot) return true;
|
|
||||||
if(this->n_ff != other.n_ff) return true;
|
|
||||||
|
|
||||||
if(std::abs(this->f_norm_eps - other.f_norm_eps) > EPSILON) return true;
|
|
||||||
if(std::abs(this->f_norm_rms_eps - other.f_norm_rms_eps) > EPSILON) return true;
|
|
||||||
if(std::abs(this->rope_freq_base_train - other.rope_freq_base_train) > EPSILON) return true;
|
|
||||||
if(std::abs(this->rope_freq_scale_train - other.rope_freq_scale_train) > EPSILON) return true;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_gqa() const {
|
uint32_t n_gqa() const {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue