From 54b75a77fb58292c6dde89d213e82fae5b171d68 Mon Sep 17 00:00:00 2001 From: Stephan Walter Date: Fri, 24 Mar 2023 10:26:44 +0100 Subject: [PATCH] Be more strict about converting float to double --- CMakeLists.txt | 5 +- Makefile | 6 +- examples/common.cpp | 6 +- examples/embedding/embedding.cpp | 2 +- examples/main/main.cpp | 11 +-- examples/quantize/quantize.cpp | 4 +- ggml.c | 112 ++++++++++++++++--------------- llama.cpp | 28 ++++---- 8 files changed, 91 insertions(+), 83 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 27a222a16..241be4c15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,17 +124,18 @@ if (LLAMA_ALL_WARNINGS) -Wall -Wextra -Wpedantic - -Wshadow -Wcast-qual + -Wdouble-promotion + -Wshadow -Wstrict-prototypes -Wpointer-arith - -Wno-unused-function ) set(cxx_flags -Wall -Wextra -Wpedantic -Wcast-qual + -Wdouble-promotion ) else() # todo : msvc diff --git a/Makefile b/Makefile index 973b2951b..d419d9e61 100644 --- a/Makefile +++ b/Makefile @@ -31,8 +31,10 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC +CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC \ + -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC \ + -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion LDFLAGS = # OS specific diff --git a/examples/common.cpp b/examples/common.cpp index 880ebe9a2..af3ad9eb7 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -215,13 +215,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); - fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); + fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p); fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); - fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); - fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); + fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp); fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n"); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index d397f35fd..c7eb81cd5 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -89,7 +89,7 @@ int main(int argc, char ** argv) { const auto embeddings = llama_get_embeddings(ctx); for (int i = 0; i < n_embd; i++) { - printf("%f ", embeddings[i]); + printf("%f ", (double)embeddings[i]); } printf("\n"); } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d5ab2cf75..3836562e8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -209,7 +209,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } } - fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", + (double)params.temp, params.top_k, (double)params.top_p, params.repeat_last_n, (double)params.repeat_penalty); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -274,10 +275,10 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // out of user input, sample next token - const float top_k = params.top_k; - const float top_p = params.top_p; - const float temp = params.temp; - const float repeat_penalty = params.repeat_penalty; + const int top_k = params.top_k; + const double top_p = (double)params.top_p; + const double temp = (double)params.temp; + const double repeat_penalty = (double)params.repeat_penalty; llama_token id = 0; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 3888ff587..b444328ac 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -50,8 +50,8 @@ int main(int argc, char ** argv) { const int64_t t_main_end_us = ggml_time_us(); printf("\n"); - printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f); - printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0); } return 0; diff --git a/ggml.c b/ggml.c index bf8ec8ab2..6838aa338 100644 --- a/ggml.c +++ b/ggml.c @@ -488,8 +488,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r const float v0 = x[i*QK + l + 0]*id; const float v1 = x[i*QK + l + 1]*id; - const uint8_t vi0 = ((int8_t) (round(v0))) + 8; - const uint8_t vi1 = ((int8_t) (round(v1))) + 8; + const uint8_t vi0 = (int8_t)roundf(v0) + 8; + const uint8_t vi1 = (int8_t)roundf(v1) + 8; assert(vi0 >= 0 && vi0 < 16); assert(vi1 >= 0 && vi1 < 16); @@ -716,8 +716,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int const float v0 = (x[i*QK + l + 0] - min)*id; const float v1 = (x[i*QK + l + 1] - min)*id; - const uint8_t vi0 = round(v0); - const uint8_t vi1 = round(v1); + const uint8_t vi0 = roundf(v0); + const uint8_t vi1 = roundf(v1); assert(vi0 >= 0 && vi0 < 16); assert(vi1 >= 0 && vi1 < 16); @@ -1437,9 +1437,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { - ggml_float sumf = 0.0; - #ifdef GGML_SIMD + float sumf = 0.0f; const int np = (n & ~(GGML_F32_STEP - 1)); GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; @@ -1465,8 +1464,9 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float } #else // scalar + ggml_float sumf = 0.0; for (int i = 0; i < n; ++i) { - sumf += x[i]*y[i]; + sumf += (ggml_float)(x[i]*y[i]); } #endif @@ -1505,7 +1505,7 @@ static inline __m512 dot_q4_0_oneblock_avx512( #endif inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { - ggml_float sumf = 0.0; + float sumf = 0.0f; #if defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -1936,7 +1936,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { - ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + float sumf[GGML_VEC_DOT_UNROLL] = { 0.0f }; ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; @@ -2049,19 +2049,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } -inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); } +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -static const ggml_float GELU_COEF_A = 0.044715; -static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876; +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { - return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { @@ -2090,7 +2090,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { - return x/(1.0 + exp(-x)); + return x/(1.0f + expf(-x)); } inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { @@ -2121,7 +2121,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; for (int i = 0; i < n; ++i) { - sum += x[i]; + sum += (ggml_float)x[i]; } *s = sum; #else @@ -2131,7 +2131,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE - ggml_float max = -INFINITY; + float max = -INFINITY; for (int i = 0; i < n; ++i) { max = MAX(max, x[i]); } @@ -2141,7 +2141,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #endif } -inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} // // logging @@ -2540,7 +2543,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); - table_exp_f16[i] = GGML_FP32_TO_FP16(exp(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); @@ -5583,7 +5586,7 @@ static void ggml_compute_forward_norm_f32( const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - const ggml_float eps = 1e-5f; // TODO: make this a parameter + const float eps = 1e-5f; // TODO: make this a parameter // TODO: optimize for (int i03 = 0; i03 < ne03; i03++) { @@ -5591,23 +5594,24 @@ static void ggml_compute_forward_norm_f32( for (int i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float mean = 0.0; + ggml_float sum = 0.0; for (int i00 = 0; i00 < ne00; i00++) { - mean += x[i00]; + sum += (ggml_float)x[i00]; } - mean /= ne00; + float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); ggml_float sum2 = 0.0; for (int i00 = 0; i00 < ne00; i00++) { - ggml_float v = x[i00] - mean; + float v = x[i00] - mean; y[i00] = v; - sum2 += v*v; + sum2 += (ggml_float)(v*v); } - const float scale = 1.0/sqrt(sum2/ne00 + eps); + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + eps); ggml_vec_scale_f32(ne00, y, scale); } @@ -5665,7 +5669,7 @@ static void ggml_compute_forward_rms_norm_f32( const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; - const ggml_float eps = 1e-6f; // TODO: make this a parameter + const float eps = 1e-6f; // TODO: make this a parameter // TODO: optimize for (int i03 = 0; i03 < ne03; i03++) { @@ -5673,12 +5677,12 @@ static void ggml_compute_forward_rms_norm_f32( for (int i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float mean = 0.0; + ggml_float sum = 0.0; for (int i00 = 0; i00 < ne00; i00++) { - mean += x[i00] * x[i00]; + sum += (ggml_float)(x[i00] * x[i00]); } - mean /= ne00; + float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); @@ -5687,7 +5691,7 @@ static void ggml_compute_forward_rms_norm_f32( // y[i00] = x[i00]; // } - const float scale = 1.0/sqrt(mean + eps); + const float scale = 1.0f/sqrtf(mean + eps); ggml_vec_scale_f32(ne00, y, scale); } @@ -6913,12 +6917,12 @@ static void ggml_compute_forward_soft_max_f32( ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); - sum += val; + sum += (ggml_float)val; p[i] = val; } } - assert(sum > 0.0f); + assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(nc, p, sum); @@ -7002,8 +7006,8 @@ static void ggml_compute_forward_rope_f32( const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - double x0 = src[0]; - double x1 = src[1]; + double x0 = (double)src[0]; + double x1 = (double)src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -7052,14 +7056,14 @@ static void ggml_compute_forward_rope_f16( for (int i0 = 0; i0 < n_dims; i0 += 2) { const double theta = pow(10000.0, ((double)-i0)/n_dims); - const double cos_theta = cos(p*theta); - const double sin_theta = sin(p*theta); + const float cos_theta = cos(p*theta); + const float sin_theta = sin(p*theta); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - double x0 = ggml_fp16_to_fp32(src[0]); - double x1 = ggml_fp16_to_fp32(src[1]); + float x0 = ggml_fp16_to_fp32(src[0]); + float x1 = ggml_fp16_to_fp32(src[1]); dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta); dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta); @@ -7735,7 +7739,7 @@ static void ggml_compute_forward_flash_attn_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - const float scale = 1.0/sqrt((double) D); + const float scale = 1.0f/sqrtf(D); //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); @@ -7782,7 +7786,7 @@ static void ggml_compute_forward_flash_attn_f32( float max = -INFINITY; ggml_vec_max_f32(M, &max, S); - float sum = 0.0f; + ggml_float sum = 0.0; { #ifdef GGML_SOFT_MAX_ACCELERATE max = -max; @@ -7803,7 +7807,7 @@ static void ggml_compute_forward_flash_attn_f32( ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += val; + sump[j] += (ggml_float)val; SS[j] = val; } } @@ -7815,7 +7819,7 @@ static void ggml_compute_forward_flash_attn_f32( #endif } - assert(sum > 0.0f); + assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(M, S, sum); @@ -7944,7 +7948,7 @@ static void ggml_compute_forward_flash_attn_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - const float scale = 1.0/sqrt((double) D); + const float scale = 1.0f/sqrtf(D); //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); @@ -8008,7 +8012,7 @@ static void ggml_compute_forward_flash_attn_f16( float max = -INFINITY; ggml_vec_max_f32(M, &max, S); - float sum = 0.0f; + ggml_float sum = 0.0; { #ifdef GGML_SOFT_MAX_ACCELERATE max = -max; @@ -8029,7 +8033,7 @@ static void ggml_compute_forward_flash_attn_f16( ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += val; + sump[j] += (ggml_float)val; SS[j] = val; } } @@ -8041,7 +8045,7 @@ static void ggml_compute_forward_flash_attn_f16( #endif } - assert(sum > 0.0f); + assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(M, S, sum); @@ -9566,7 +9570,7 @@ label=\"%d [%d, %d] | %s", fprintf(fp, " \"%p\" [ \ style = filled; fillcolor = %s; shape = record; \ label=\"%.1e\"; ]\n", - (void *) node, color, ggml_get_f32_1d(node, 0)); + (void *) node, color, (double)ggml_get_f32_1d(node, 0)); } else { fprintf(fp, " \"%p\" [ \ style = filled; fillcolor = %s; shape = record; \ @@ -9804,7 +9808,7 @@ static enum ggml_opt_result ggml_opt_adam( if (params.past <= t) { const float rate = (pf[t%params.past] - fx)/fx; - if (fabs(rate) < params.delta) { + if (fabsf(rate) < params.delta) { return GGML_OPT_OK; } } @@ -9883,7 +9887,7 @@ static enum ggml_opt_result linesearch_backtracking( const float dec = 0.5f; const float inc = 2.1f; - if (*step <= 0.) { + if (*step <= 0.f) { return GGML_LINESEARCH_INVALID_PARAMETERS; } @@ -9971,7 +9975,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( struct ggml_cgraph * gb) { if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { return GGML_OPT_INVALID_WOLFE; } } @@ -10092,8 +10096,8 @@ static enum ggml_opt_result ggml_opt_lbfgs( GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - if (xnorm < 1.0) { - xnorm = 1.0; + if (xnorm < 1.0f) { + xnorm = 1.0f; } if (gnorm/xnorm <= params.lbfgs.eps) { // converged @@ -10106,7 +10110,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( if (params.past <= k) { const float rate = (pf[k%params.past] - fx)/fx; - if (fabs(rate) < params.delta) { + if (fabsf(rate) < params.delta) { return GGML_OPT_OK; } } diff --git a/llama.cpp b/llama.cpp index b0eab2e72..4e0071183 100644 --- a/llama.cpp +++ b/llama.cpp @@ -922,7 +922,7 @@ static bool llama_eval_internal( struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, - ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))); + ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); // KQ_masked = mask_past(KQ_scaled) struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); @@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k( // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if (plogits[i] < 0.0) { - logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); + if (plogits[i] < 0.0f) { + logits_id.push_back(std::make_pair((double)plogits[i]*scale*repeat_penalty, i)); } else { - logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); + logits_id.push_back(std::make_pair((double)plogits[i]*scale/repeat_penalty, i)); } } else { - logits_id.push_back(std::make_pair(plogits[i]*scale, i)); + logits_id.push_back(std::make_pair((double)plogits[i]*scale, i)); } } } @@ -1310,8 +1310,8 @@ static llama_vocab::id llama_sample_top_p_top_k( p /= sum; } - if (top_p < 1.0f) { - double cumsum = 0.0f; + if (top_p < 1.0) { + double cumsum = 0.0; for (int i = 0; i < (int) probs.size(); i++) { cumsum += probs[i]; if (cumsum >= top_p) { @@ -1590,7 +1590,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s } for (int i = 0; i < (int) hist_cur.size(); ++i) { - printf("%5.3f ", hist_cur[i] / (float)nelements); + printf("%5.3f ", hist_cur[i] / (double)nelements); } printf("\n"); } else { @@ -1613,7 +1613,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s printf("%s: hist: ", __func__); for (int i = 0; i < (int) hist_all.size(); ++i) { - printf("%5.3f ", hist_all[i] / (float)sum_all); + printf("%5.3f ", hist_all[i] / (double)sum_all); } printf("\n"); } @@ -1828,11 +1828,11 @@ void llama_print_timings(struct llama_context * ctx) { const int32_t n_p_eval = std::max(1, ctx->n_p_eval); fprintf(stderr, "\n"); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); - fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample); - fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval); - fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_eval_us, n_eval, 1e-3f * ctx->t_eval_us / n_eval); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); + fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample); + fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval); + fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); } void llama_reset_timings(struct llama_context * ctx) {