diff --git a/Makefile b/Makefile index d419d9e61..9cfa89f7a 100644 --- a/Makefile +++ b/Makefile @@ -31,12 +31,14 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC \ - -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC \ - -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion +CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC +CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC LDFLAGS = +# warnings +CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function +CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function + # OS specific # TODO: support Windows ifeq ($(UNAME_S),Linux) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index c7eb81cd5..d397f35fd 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -89,7 +89,7 @@ int main(int argc, char ** argv) { const auto embeddings = llama_get_embeddings(ctx); for (int i = 0; i < n_embd; i++) { - printf("%f ", (double)embeddings[i]); + printf("%f ", embeddings[i]); } printf("\n"); } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3836562e8..3130aef0c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -210,7 +210,7 @@ int main(int argc, char ** argv) { } } fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", - (double)params.temp, params.top_k, (double)params.top_p, params.repeat_last_n, (double)params.repeat_penalty); + params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); @@ -275,10 +275,10 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // out of user input, sample next token - const int top_k = params.top_k; - const double top_p = (double)params.top_p; - const double temp = (double)params.temp; - const double repeat_penalty = (double)params.repeat_penalty; + const int32_t top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + const float repeat_penalty = params.repeat_penalty; llama_token id = 0; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 41e50be7f..09693a05e 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1,15 +1,15 @@ #include "common.h" #include "llama.h" -std::vector softmax(const std::vector& logits) { - std::vector probs(logits.size()); +std::vector softmax(const std::vector& logits) { + std::vector probs(logits.size()); float max_logit = logits[0]; for (float v : logits) max_logit = std::max(max_logit, v); double sum_exp = 0.0; for (size_t i = 0; i < logits.size(); i++) { // Subtract the maximum logit value from the current logit value for numerical stability - float logit = logits[i] - max_logit; - double exp_logit = std::exp((double)logit); + const float logit = logits[i] - max_logit; + const float exp_logit = std::expf(logit); sum_exp += exp_logit; probs[i] = exp_logit; } @@ -24,14 +24,16 @@ void perplexity(llama_context * ctx, const gpt_params & params) { auto tokens = ::llama_tokenize(ctx, params.prompt, true); int count = 0; - double nll = 0.0; int seq_count = tokens.size() / params.n_ctx; + double nll = 0.0; + fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count); for (int i = 0; i < seq_count; ++i) { int start = i * params.n_ctx; - int end = start + params.n_ctx - 1; + int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512 + // it is better to always be power of 2 for better performance std::vector embd(tokens.begin() + start, tokens.begin() + end); auto start_t = std::chrono::high_resolution_clock::now(); if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) { @@ -40,7 +42,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) { } auto end_t = std::chrono::high_resolution_clock::now(); if (i == 0) { - double seconds = std::chrono::duration(end_t - start_t).count(); + const float seconds = std::chrono::duration(end_t - start_t).count(); printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0)); } // We get the logits for all the tokens in the context window (params.n_ctx) @@ -63,7 +65,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) { std::vector tok_logits( logits + j * n_vocab, logits + (j + 1) * n_vocab); - double prob = softmax(tok_logits)[tokens[start + j + 1]]; + const float prob = softmax(tok_logits)[tokens[start + j + 1]]; nll += -std::log(prob); ++count; } diff --git a/ggml.c b/ggml.c index 6838aa338..83395a701 100644 --- a/ggml.c +++ b/ggml.c @@ -150,10 +150,10 @@ typedef double ggml_float; // #include -#define GGML_COMPUTE_FP16_TO_FP32(x) (x) +#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) #define GGML_COMPUTE_FP32_TO_FP16(x) (x) -#define GGML_FP16_TO_FP32(x) (x) +#define GGML_FP16_TO_FP32(x) ((float) (x)) #define GGML_FP32_TO_FP16(x) (x) #else @@ -322,7 +322,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { // note: do not use these inside ggml.c // these are meant to be used via the ggml.h API float ggml_fp16_to_fp32(ggml_fp16_t x) { - return GGML_FP16_TO_FP32(x); + return (float) GGML_FP16_TO_FP32(x); } ggml_fp16_t ggml_fp32_to_fp16(float x) { @@ -566,7 +566,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); const float d = amax / ((1 << 3) - 1); - const float id = d ? 1.0/d : 0.0; + const float id = d ? 1.0f/d : 0.0f; y[i].d = d; @@ -1001,7 +1001,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in } \ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = vaddvq_f32(vaddq_f32(t0, t1)); \ + res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ } #define GGML_F16_VEC GGML_F16x8 @@ -1505,7 +1505,7 @@ static inline __m512 dot_q4_0_oneblock_avx512( #endif inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { - float sumf = 0.0f; + ggml_float sumf = 0.0; #if defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -1529,11 +1529,11 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t // leftovers for (int i = np; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); } #else for (int i = 0; i < n; ++i) { - sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]); + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); } #endif @@ -1549,7 +1549,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void const block_q4_0 * restrict x = vx; const block_q4_0 * restrict y = vy; - float sumf = 0.0; + ggml_float sumf = 0.0; #if defined(__ARM_NEON) float sum0 = 0.0f; @@ -1644,7 +1644,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void #endif } - sumf = sum0 + sum1; + sumf = (ggml_float)(sum0 + sum1); #elif defined(__AVX512F__) // Initialize accumulator with zeros __m512 acc0 = _mm512_setzero_ps(); @@ -1936,7 +1936,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { - float sumf[GGML_VEC_DOT_UNROLL] = { 0.0f }; + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; @@ -1972,13 +1972,13 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re // leftovers for (int i = np; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); } } #else for (int i = 0; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]); + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); } } #endif @@ -6998,16 +6998,16 @@ static void ggml_compute_forward_rope_f32( const int p = (mode == 0 ? n_past + i2 : i2); for (int i1 = 0; i1 < ne1; i1++) { for (int i0 = 0; i0 < n_dims; i0 += 2) { - const double theta = pow(10000.0, ((double)-i0)/n_dims); + const float theta = powf(10000.0, ((float)-i0)/n_dims); - const double cos_theta = cos(p*theta); - const double sin_theta = sin(p*theta); + const float cos_theta = cosf(p*theta); + const float sin_theta = sinf(p*theta); const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - double x0 = (double)src[0]; - double x1 = (double)src[1]; + const float x0 = src[0]; + const float x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -7054,16 +7054,16 @@ static void ggml_compute_forward_rope_f16( const int p = (mode == 0 ? n_past + i2 : i2); for (int i1 = 0; i1 < ne1; i1++) { for (int i0 = 0; i0 < n_dims; i0 += 2) { - const double theta = pow(10000.0, ((double)-i0)/n_dims); + const float theta = powf(10000.0, ((float)-i0)/n_dims); - const float cos_theta = cos(p*theta); - const float sin_theta = sin(p*theta); + const float cos_theta = cosf(p*theta); + const float sin_theta = sinf(p*theta); const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float x0 = ggml_fp16_to_fp32(src[0]); - float x1 = ggml_fp16_to_fp32(src[1]); + const float x0 = ggml_fp16_to_fp32(src[0]); + const float x1 = ggml_fp16_to_fp32(src[1]); dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta); dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta); diff --git a/llama.cpp b/llama.cpp index 4e0071183..ee7eb8ea7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -779,8 +779,8 @@ static bool llama_model_load( // progress if (progress_callback) { - double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset); - double current_progress = (double(i) + current_file_progress) / double(n_parts); + float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset); + float current_progress = (float(i) + current_file_progress) / float(n_parts); progress_callback(current_progress, progress_callback_user_data); } if (model.n_loaded % 8 == 0) { @@ -1240,12 +1240,12 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // sampling // -static void sample_top_k(std::vector> & logits_id, int top_k) { +static void sample_top_k(std::vector> & logits_id, int top_k) { // find the top k tokens std::partial_sort( logits_id.begin(), logits_id.begin() + top_k, logits_id.end(), - [](const std::pair & a, const std::pair & b) { + [](const std::pair & a, const std::pair & b) { return a.first > b.first; }); @@ -1256,9 +1256,9 @@ static llama_vocab::id llama_sample_top_p_top_k( llama_context & lctx, const std::vector & last_n_tokens, int top_k, - double top_p, - double temp, - double repeat_penalty) { + float top_p, + float temp, + float repeat_penalty) { auto & rng = lctx.rng; const int n_logits = lctx.model.hparams.n_vocab; @@ -1266,41 +1266,41 @@ static llama_vocab::id llama_sample_top_p_top_k( const auto & logits = lctx.logits; const auto * plogits = logits.data() + logits.size() - n_logits; - std::vector> logits_id; + std::vector> logits_id; logits_id.reserve(n_logits); { - const double scale = 1.0/temp; + const float scale = 1.0f/temp; for (int i = 0; i < n_logits; ++i) { // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability if (plogits[i] < 0.0f) { - logits_id.push_back(std::make_pair((double)plogits[i]*scale*repeat_penalty, i)); + logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); } else { - logits_id.push_back(std::make_pair((double)plogits[i]*scale/repeat_penalty, i)); + logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); } } else { - logits_id.push_back(std::make_pair((double)plogits[i]*scale, i)); + logits_id.push_back(std::make_pair(plogits[i]*scale, i)); } } } sample_top_k(logits_id, top_k); - double maxl = -std::numeric_limits::infinity(); + float maxl = -std::numeric_limits::infinity(); for (const auto & kv : logits_id) { maxl = std::max(maxl, kv.first); } // compute probs for the top k tokens - std::vector probs; + std::vector probs; probs.reserve(logits_id.size()); double sum = 0.0; for (const auto & kv : logits_id) { - double p = exp(kv.first - maxl); + const float p = expf(kv.first - maxl); probs.push_back(p); sum += p; } @@ -1590,7 +1590,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s } for (int i = 0; i < (int) hist_cur.size(); ++i) { - printf("%5.3f ", hist_cur[i] / (double)nelements); + printf("%5.3f ", hist_cur[i] / float(nelements)); } printf("\n"); } else { @@ -1613,7 +1613,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s printf("%s: hist: ", __func__); for (int i = 0; i < (int) hist_all.size(); ++i) { - printf("%5.3f ", hist_all[i] / (double)sum_all); + printf("%5.3f ", hist_all[i] / float(sum_all)); } printf("\n"); } @@ -1795,9 +1795,9 @@ llama_token llama_sample_top_p_top_k( const llama_token * last_n_tokens_data, int last_n_tokens_size, int top_k, - double top_p, - double temp, - double repeat_penalty) { + float top_p, + float temp, + float repeat_penalty) { const int64_t t_start_sample_us = ggml_time_us(); llama_token result = 0; diff --git a/llama.h b/llama.h index d3f4cae61..f5a576c1e 100644 --- a/llama.h +++ b/llama.h @@ -45,7 +45,7 @@ extern "C" { } llama_token_data; - typedef void (*llama_progress_callback)(double progress, void *ctx); + typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { int n_ctx; // text context @@ -134,9 +134,9 @@ extern "C" { const llama_token * last_n_tokens_data, int last_n_tokens_size, int top_k, - double top_p, - double temp, - double repeat_penalty); + float top_p, + float temp, + float repeat_penalty); // Performance information LLAMA_API void llama_print_timings(struct llama_context * ctx);