all : be more strict about converting float to double (#458)
* Be more strict about converting float to double * Test equivalence of round, SILU implementations Test module is commented out in CMakeLists.txt because the tests may take a long time, depending on how much the compiler optimizes. * Fix softmax in perplexity.cpp * all : prefer float over double where appropriate * perplexity : add <cmath> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
20e1e84884
commit
436e561931
11 changed files with 185 additions and 117 deletions
138
ggml.c
138
ggml.c
|
@ -150,10 +150,10 @@ typedef double ggml_float;
|
|||
//
|
||||
#include <arm_neon.h>
|
||||
|
||||
#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
|
||||
#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
|
||||
#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
|
||||
|
||||
#define GGML_FP16_TO_FP32(x) (x)
|
||||
#define GGML_FP16_TO_FP32(x) ((float) (x))
|
||||
#define GGML_FP32_TO_FP16(x) (x)
|
||||
|
||||
#else
|
||||
|
@ -322,7 +322,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
|||
// note: do not use these inside ggml.c
|
||||
// these are meant to be used via the ggml.h API
|
||||
float ggml_fp16_to_fp32(ggml_fp16_t x) {
|
||||
return GGML_FP16_TO_FP32(x);
|
||||
return (float) GGML_FP16_TO_FP32(x);
|
||||
}
|
||||
|
||||
ggml_fp16_t ggml_fp32_to_fp16(float x) {
|
||||
|
@ -488,8 +488,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
|||
const float v0 = x[i*QK + l + 0]*id;
|
||||
const float v1 = x[i*QK + l + 1]*id;
|
||||
|
||||
const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
|
||||
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
|
||||
const uint8_t vi0 = (int8_t)roundf(v0) + 8;
|
||||
const uint8_t vi1 = (int8_t)roundf(v1) + 8;
|
||||
|
||||
assert(vi0 >= 0 && vi0 < 16);
|
||||
assert(vi1 >= 0 && vi1 < 16);
|
||||
|
@ -566,7 +566,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|||
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
|
||||
|
||||
const float d = amax / ((1 << 3) - 1);
|
||||
const float id = d ? 1.0/d : 0.0;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y[i].d = d;
|
||||
|
||||
|
@ -716,8 +716,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
|||
const float v0 = (x[i*QK + l + 0] - min)*id;
|
||||
const float v1 = (x[i*QK + l + 1] - min)*id;
|
||||
|
||||
const uint8_t vi0 = round(v0);
|
||||
const uint8_t vi1 = round(v1);
|
||||
const uint8_t vi0 = roundf(v0);
|
||||
const uint8_t vi1 = roundf(v1);
|
||||
|
||||
assert(vi0 >= 0 && vi0 < 16);
|
||||
assert(vi1 >= 0 && vi1 < 16);
|
||||
|
@ -1001,7 +1001,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|||
} \
|
||||
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
|
||||
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
|
||||
res = vaddvq_f32(vaddq_f32(t0, t1)); \
|
||||
res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
|
||||
}
|
||||
|
||||
#define GGML_F16_VEC GGML_F16x8
|
||||
|
@ -1437,9 +1437,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co
|
|||
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
||||
|
||||
inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
|
||||
ggml_float sumf = 0.0;
|
||||
|
||||
#ifdef GGML_SIMD
|
||||
float sumf = 0.0f;
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
|
@ -1465,8 +1464,9 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|||
}
|
||||
#else
|
||||
// scalar
|
||||
ggml_float sumf = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
sumf += x[i]*y[i];
|
||||
sumf += (ggml_float)(x[i]*y[i]);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1529,11 +1529,11 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
|
||||
sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
|
||||
sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -1549,7 +1549,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|||
const block_q4_0 * restrict x = vx;
|
||||
const block_q4_0 * restrict y = vy;
|
||||
|
||||
float sumf = 0.0;
|
||||
ggml_float sumf = 0.0;
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
float sum0 = 0.0f;
|
||||
|
@ -1644,7 +1644,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|||
#endif
|
||||
}
|
||||
|
||||
sumf = sum0 + sum1;
|
||||
sumf = (ggml_float)(sum0 + sum1);
|
||||
#elif defined(__AVX512F__)
|
||||
// Initialize accumulator with zeros
|
||||
__m512 acc0 = _mm512_setzero_ps();
|
||||
|
@ -1972,13 +1972,13 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
|
|||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
|
||||
sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
|
||||
sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -2049,19 +2049,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
|||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
|
||||
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
|
||||
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
|
||||
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
|
||||
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
|
||||
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
|
||||
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
|
||||
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
|
||||
|
||||
static const ggml_float GELU_COEF_A = 0.044715;
|
||||
static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
inline static float ggml_gelu_f32(float x) {
|
||||
return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
|
@ -2090,7 +2090,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
|
|||
|
||||
// Sigmoid Linear Unit (SiLU) function
|
||||
inline static float ggml_silu_f32(float x) {
|
||||
return x/(1.0 + exp(-x));
|
||||
return x/(1.0f + expf(-x));
|
||||
}
|
||||
|
||||
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
|
@ -2121,7 +2121,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
|||
#ifndef GGML_USE_ACCELERATE
|
||||
ggml_float sum = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
sum += x[i];
|
||||
sum += (ggml_float)x[i];
|
||||
}
|
||||
*s = sum;
|
||||
#else
|
||||
|
@ -2131,7 +2131,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
|||
|
||||
inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
|
||||
#ifndef GGML_USE_ACCELERATE
|
||||
ggml_float max = -INFINITY;
|
||||
float max = -INFINITY;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
max = MAX(max, x[i]);
|
||||
}
|
||||
|
@ -2141,7 +2141,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
|
|||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
|
||||
inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
|
||||
ggml_vec_norm_f32(n, s, x);
|
||||
*s = 1.f/(*s);
|
||||
}
|
||||
|
||||
//
|
||||
// logging
|
||||
|
@ -2540,7 +2543,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
|
||||
table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
|
||||
table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
|
||||
table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f));
|
||||
table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
|
||||
}
|
||||
|
||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
@ -5583,7 +5586,7 @@ static void ggml_compute_forward_norm_f32(
|
|||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
const ggml_float eps = 1e-5f; // TODO: make this a parameter
|
||||
const float eps = 1e-5f; // TODO: make this a parameter
|
||||
|
||||
// TODO: optimize
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
|
@ -5591,23 +5594,24 @@ static void ggml_compute_forward_norm_f32(
|
|||
for (int i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float mean = 0.0;
|
||||
ggml_float sum = 0.0;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
mean += x[i00];
|
||||
sum += (ggml_float)x[i00];
|
||||
}
|
||||
|
||||
mean /= ne00;
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
ggml_float sum2 = 0.0;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
ggml_float v = x[i00] - mean;
|
||||
float v = x[i00] - mean;
|
||||
y[i00] = v;
|
||||
sum2 += v*v;
|
||||
sum2 += (ggml_float)(v*v);
|
||||
}
|
||||
|
||||
const float scale = 1.0/sqrt(sum2/ne00 + eps);
|
||||
float variance = sum2/ne00;
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
|
@ -5665,7 +5669,7 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
const ggml_float eps = 1e-6f; // TODO: make this a parameter
|
||||
const float eps = 1e-6f; // TODO: make this a parameter
|
||||
|
||||
// TODO: optimize
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
|
@ -5673,12 +5677,12 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
for (int i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float mean = 0.0;
|
||||
ggml_float sum = 0.0;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
mean += x[i00] * x[i00];
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
}
|
||||
|
||||
mean /= ne00;
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
|
@ -5687,7 +5691,7 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
// y[i00] = x[i00];
|
||||
// }
|
||||
|
||||
const float scale = 1.0/sqrt(mean + eps);
|
||||
const float scale = 1.0f/sqrtf(mean + eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
|
@ -6913,12 +6917,12 @@ static void ggml_compute_forward_soft_max_f32(
|
|||
ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
|
||||
memcpy(&scvt, &s, sizeof(scvt));
|
||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
|
||||
sum += val;
|
||||
sum += (ggml_float)val;
|
||||
p[i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
assert(sum > 0.0f);
|
||||
assert(sum > 0.0);
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(nc, p, sum);
|
||||
|
@ -6994,16 +6998,16 @@ static void ggml_compute_forward_rope_f32(
|
|||
const int p = (mode == 0 ? n_past + i2 : i2);
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const double theta = pow(10000.0, ((double)-i0)/n_dims);
|
||||
const float theta = powf(10000.0, ((float)-i0)/n_dims);
|
||||
|
||||
const double cos_theta = cos(p*theta);
|
||||
const double sin_theta = sin(p*theta);
|
||||
const float cos_theta = cosf(p*theta);
|
||||
const float sin_theta = sinf(p*theta);
|
||||
|
||||
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
double x0 = src[0];
|
||||
double x1 = src[1];
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
|
@ -7050,16 +7054,16 @@ static void ggml_compute_forward_rope_f16(
|
|||
const int p = (mode == 0 ? n_past + i2 : i2);
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
const double theta = pow(10000.0, ((double)-i0)/n_dims);
|
||||
const float theta = powf(10000.0, ((float)-i0)/n_dims);
|
||||
|
||||
const double cos_theta = cos(p*theta);
|
||||
const double sin_theta = sin(p*theta);
|
||||
const float cos_theta = cosf(p*theta);
|
||||
const float sin_theta = sinf(p*theta);
|
||||
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
double x0 = ggml_fp16_to_fp32(src[0]);
|
||||
double x1 = ggml_fp16_to_fp32(src[1]);
|
||||
const float x0 = ggml_fp16_to_fp32(src[0]);
|
||||
const float x1 = ggml_fp16_to_fp32(src[1]);
|
||||
|
||||
dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
|
||||
dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
|
||||
|
@ -7735,7 +7739,7 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float scale = 1.0/sqrt((double) D);
|
||||
const float scale = 1.0f/sqrtf(D);
|
||||
|
||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||
|
||||
|
@ -7782,7 +7786,7 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(M, &max, S);
|
||||
|
||||
float sum = 0.0f;
|
||||
ggml_float sum = 0.0;
|
||||
{
|
||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||
max = -max;
|
||||
|
@ -7803,7 +7807,7 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
||||
sump[j] += val;
|
||||
sump[j] += (ggml_float)val;
|
||||
SS[j] = val;
|
||||
}
|
||||
}
|
||||
|
@ -7815,7 +7819,7 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
#endif
|
||||
}
|
||||
|
||||
assert(sum > 0.0f);
|
||||
assert(sum > 0.0);
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(M, S, sum);
|
||||
|
@ -7944,7 +7948,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float scale = 1.0/sqrt((double) D);
|
||||
const float scale = 1.0f/sqrtf(D);
|
||||
|
||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||
|
||||
|
@ -8008,7 +8012,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(M, &max, S);
|
||||
|
||||
float sum = 0.0f;
|
||||
ggml_float sum = 0.0;
|
||||
{
|
||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||
max = -max;
|
||||
|
@ -8029,7 +8033,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
||||
sump[j] += val;
|
||||
sump[j] += (ggml_float)val;
|
||||
SS[j] = val;
|
||||
}
|
||||
}
|
||||
|
@ -8041,7 +8045,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||
#endif
|
||||
}
|
||||
|
||||
assert(sum > 0.0f);
|
||||
assert(sum > 0.0);
|
||||
|
||||
sum = 1.0/sum;
|
||||
ggml_vec_scale_f32(M, S, sum);
|
||||
|
@ -9566,7 +9570,7 @@ label=\"%d [%d, %d] | <x>%s",
|
|||
fprintf(fp, " \"%p\" [ \
|
||||
style = filled; fillcolor = %s; shape = record; \
|
||||
label=\"<x>%.1e\"; ]\n",
|
||||
(void *) node, color, ggml_get_f32_1d(node, 0));
|
||||
(void *) node, color, (double)ggml_get_f32_1d(node, 0));
|
||||
} else {
|
||||
fprintf(fp, " \"%p\" [ \
|
||||
style = filled; fillcolor = %s; shape = record; \
|
||||
|
@ -9804,7 +9808,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
if (params.past <= t) {
|
||||
const float rate = (pf[t%params.past] - fx)/fx;
|
||||
|
||||
if (fabs(rate) < params.delta) {
|
||||
if (fabsf(rate) < params.delta) {
|
||||
return GGML_OPT_OK;
|
||||
}
|
||||
}
|
||||
|
@ -9883,7 +9887,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
const float dec = 0.5f;
|
||||
const float inc = 2.1f;
|
||||
|
||||
if (*step <= 0.) {
|
||||
if (*step <= 0.f) {
|
||||
return GGML_LINESEARCH_INVALID_PARAMETERS;
|
||||
}
|
||||
|
||||
|
@ -9971,7 +9975,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
struct ggml_cgraph * gb) {
|
||||
if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
|
||||
params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
|
||||
if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
|
||||
if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
|
||||
return GGML_OPT_INVALID_WOLFE;
|
||||
}
|
||||
}
|
||||
|
@ -10092,8 +10096,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
|
||||
GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
|
||||
|
||||
if (xnorm < 1.0) {
|
||||
xnorm = 1.0;
|
||||
if (xnorm < 1.0f) {
|
||||
xnorm = 1.0f;
|
||||
}
|
||||
if (gnorm/xnorm <= params.lbfgs.eps) {
|
||||
// converged
|
||||
|
@ -10106,7 +10110,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
if (params.past <= k) {
|
||||
const float rate = (pf[k%params.past] - fx)/fx;
|
||||
|
||||
if (fabs(rate) < params.delta) {
|
||||
if (fabsf(rate) < params.delta) {
|
||||
return GGML_OPT_OK;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue