all : prefer float over double where appropriate
This commit is contained in:
parent
f68345e9b1
commit
61733d3b49
7 changed files with 69 additions and 65 deletions
10
Makefile
10
Makefile
|
@ -31,12 +31,14 @@ endif
|
||||||
#
|
#
|
||||||
|
|
||||||
# keep standard at C11 and C++11
|
# keep standard at C11 and C++11
|
||||||
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC \
|
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||||
-Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith
|
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC \
|
|
||||||
-Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion
|
|
||||||
LDFLAGS =
|
LDFLAGS =
|
||||||
|
|
||||||
|
# warnings
|
||||||
|
CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
|
||||||
|
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
|
||||||
|
|
||||||
# OS specific
|
# OS specific
|
||||||
# TODO: support Windows
|
# TODO: support Windows
|
||||||
ifeq ($(UNAME_S),Linux)
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
|
|
@ -89,7 +89,7 @@ int main(int argc, char ** argv) {
|
||||||
const auto embeddings = llama_get_embeddings(ctx);
|
const auto embeddings = llama_get_embeddings(ctx);
|
||||||
|
|
||||||
for (int i = 0; i < n_embd; i++) {
|
for (int i = 0; i < n_embd; i++) {
|
||||||
printf("%f ", (double)embeddings[i]);
|
printf("%f ", embeddings[i]);
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
|
@ -210,7 +210,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
|
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
|
||||||
(double)params.temp, params.top_k, (double)params.top_p, params.repeat_last_n, (double)params.repeat_penalty);
|
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
|
||||||
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
fprintf(stderr, "\n\n");
|
fprintf(stderr, "\n\n");
|
||||||
|
|
||||||
|
@ -275,10 +275,10 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||||
// out of user input, sample next token
|
// out of user input, sample next token
|
||||||
const int top_k = params.top_k;
|
const int32_t top_k = params.top_k;
|
||||||
const double top_p = (double)params.top_p;
|
const float top_p = params.top_p;
|
||||||
const double temp = (double)params.temp;
|
const float temp = params.temp;
|
||||||
const double repeat_penalty = (double)params.repeat_penalty;
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
|
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
std::vector<double> softmax(const std::vector<float>& logits) {
|
std::vector<float> softmax(const std::vector<float>& logits) {
|
||||||
std::vector<double> probs(logits.size());
|
std::vector<float> probs(logits.size());
|
||||||
float max_logit = logits[0];
|
float max_logit = logits[0];
|
||||||
for (float v : logits) max_logit = std::max(max_logit, v);
|
for (float v : logits) max_logit = std::max(max_logit, v);
|
||||||
double sum_exp = 0.0;
|
double sum_exp = 0.0;
|
||||||
for (size_t i = 0; i < logits.size(); i++) {
|
for (size_t i = 0; i < logits.size(); i++) {
|
||||||
// Subtract the maximum logit value from the current logit value for numerical stability
|
// Subtract the maximum logit value from the current logit value for numerical stability
|
||||||
float logit = logits[i] - max_logit;
|
const float logit = logits[i] - max_logit;
|
||||||
double exp_logit = std::exp((double)logit);
|
const float exp_logit = std::expf(logit);
|
||||||
sum_exp += exp_logit;
|
sum_exp += exp_logit;
|
||||||
probs[i] = exp_logit;
|
probs[i] = exp_logit;
|
||||||
}
|
}
|
||||||
|
@ -24,14 +24,16 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
||||||
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
|
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
double nll = 0.0;
|
|
||||||
int seq_count = tokens.size() / params.n_ctx;
|
int seq_count = tokens.size() / params.n_ctx;
|
||||||
|
|
||||||
|
double nll = 0.0;
|
||||||
|
|
||||||
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
|
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
|
||||||
|
|
||||||
for (int i = 0; i < seq_count; ++i) {
|
for (int i = 0; i < seq_count; ++i) {
|
||||||
int start = i * params.n_ctx;
|
int start = i * params.n_ctx;
|
||||||
int end = start + params.n_ctx - 1;
|
int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
|
||||||
|
// it is better to always be power of 2 for better performance
|
||||||
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
|
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
|
||||||
auto start_t = std::chrono::high_resolution_clock::now();
|
auto start_t = std::chrono::high_resolution_clock::now();
|
||||||
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
|
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
|
||||||
|
@ -40,7 +42,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
||||||
}
|
}
|
||||||
auto end_t = std::chrono::high_resolution_clock::now();
|
auto end_t = std::chrono::high_resolution_clock::now();
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
double seconds = std::chrono::duration<double>(end_t - start_t).count();
|
const float seconds = std::chrono::duration<float>(end_t - start_t).count();
|
||||||
printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
|
printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
|
||||||
}
|
}
|
||||||
// We get the logits for all the tokens in the context window (params.n_ctx)
|
// We get the logits for all the tokens in the context window (params.n_ctx)
|
||||||
|
@ -63,7 +65,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
||||||
std::vector<float> tok_logits(
|
std::vector<float> tok_logits(
|
||||||
logits + j * n_vocab,
|
logits + j * n_vocab,
|
||||||
logits + (j + 1) * n_vocab);
|
logits + (j + 1) * n_vocab);
|
||||||
double prob = softmax(tok_logits)[tokens[start + j + 1]];
|
const float prob = softmax(tok_logits)[tokens[start + j + 1]];
|
||||||
nll += -std::log(prob);
|
nll += -std::log(prob);
|
||||||
++count;
|
++count;
|
||||||
}
|
}
|
||||||
|
|
46
ggml.c
46
ggml.c
|
@ -150,10 +150,10 @@ typedef double ggml_float;
|
||||||
//
|
//
|
||||||
#include <arm_neon.h>
|
#include <arm_neon.h>
|
||||||
|
|
||||||
#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
|
#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
|
||||||
#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
|
#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
|
||||||
|
|
||||||
#define GGML_FP16_TO_FP32(x) (x)
|
#define GGML_FP16_TO_FP32(x) ((float) (x))
|
||||||
#define GGML_FP32_TO_FP16(x) (x)
|
#define GGML_FP32_TO_FP16(x) (x)
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
@ -322,7 +322,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||||
// note: do not use these inside ggml.c
|
// note: do not use these inside ggml.c
|
||||||
// these are meant to be used via the ggml.h API
|
// these are meant to be used via the ggml.h API
|
||||||
float ggml_fp16_to_fp32(ggml_fp16_t x) {
|
float ggml_fp16_to_fp32(ggml_fp16_t x) {
|
||||||
return GGML_FP16_TO_FP32(x);
|
return (float) GGML_FP16_TO_FP32(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_fp16_t ggml_fp32_to_fp16(float x) {
|
ggml_fp16_t ggml_fp32_to_fp16(float x) {
|
||||||
|
@ -566,7 +566,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
||||||
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
|
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
|
||||||
|
|
||||||
const float d = amax / ((1 << 3) - 1);
|
const float d = amax / ((1 << 3) - 1);
|
||||||
const float id = d ? 1.0/d : 0.0;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
y[i].d = d;
|
y[i].d = d;
|
||||||
|
|
||||||
|
@ -1001,7 +1001,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
} \
|
} \
|
||||||
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
|
const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
|
||||||
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
|
const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
|
||||||
res = vaddvq_f32(vaddq_f32(t0, t1)); \
|
res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define GGML_F16_VEC GGML_F16x8
|
#define GGML_F16_VEC GGML_F16x8
|
||||||
|
@ -1505,7 +1505,7 @@ static inline __m512 dot_q4_0_oneblock_avx512(
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
||||||
float sumf = 0.0f;
|
ggml_float sumf = 0.0;
|
||||||
|
|
||||||
#if defined(GGML_SIMD)
|
#if defined(GGML_SIMD)
|
||||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||||
|
@ -1529,11 +1529,11 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
||||||
|
|
||||||
// leftovers
|
// leftovers
|
||||||
for (int i = np; i < n; ++i) {
|
for (int i = np; i < n; ++i) {
|
||||||
sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
|
sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
|
sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -1549,7 +1549,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||||
const block_q4_0 * restrict x = vx;
|
const block_q4_0 * restrict x = vx;
|
||||||
const block_q4_0 * restrict y = vy;
|
const block_q4_0 * restrict y = vy;
|
||||||
|
|
||||||
float sumf = 0.0;
|
ggml_float sumf = 0.0;
|
||||||
|
|
||||||
#if defined(__ARM_NEON)
|
#if defined(__ARM_NEON)
|
||||||
float sum0 = 0.0f;
|
float sum0 = 0.0f;
|
||||||
|
@ -1644,7 +1644,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf = sum0 + sum1;
|
sumf = (ggml_float)(sum0 + sum1);
|
||||||
#elif defined(__AVX512F__)
|
#elif defined(__AVX512F__)
|
||||||
// Initialize accumulator with zeros
|
// Initialize accumulator with zeros
|
||||||
__m512 acc0 = _mm512_setzero_ps();
|
__m512 acc0 = _mm512_setzero_ps();
|
||||||
|
@ -1936,7 +1936,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
||||||
// compute GGML_VEC_DOT_UNROLL dot products at once
|
// compute GGML_VEC_DOT_UNROLL dot products at once
|
||||||
// xs - x row stride in bytes
|
// xs - x row stride in bytes
|
||||||
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
||||||
float sumf[GGML_VEC_DOT_UNROLL] = { 0.0f };
|
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
||||||
|
|
||||||
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
||||||
|
|
||||||
|
@ -1972,13 +1972,13 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
|
||||||
// leftovers
|
// leftovers
|
||||||
for (int i = np; i < n; ++i) {
|
for (int i = np; i < n; ++i) {
|
||||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||||
sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
|
sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||||
sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
|
sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -6998,16 +6998,16 @@ static void ggml_compute_forward_rope_f32(
|
||||||
const int p = (mode == 0 ? n_past + i2 : i2);
|
const int p = (mode == 0 ? n_past + i2 : i2);
|
||||||
for (int i1 = 0; i1 < ne1; i1++) {
|
for (int i1 = 0; i1 < ne1; i1++) {
|
||||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||||
const double theta = pow(10000.0, ((double)-i0)/n_dims);
|
const float theta = powf(10000.0, ((float)-i0)/n_dims);
|
||||||
|
|
||||||
const double cos_theta = cos(p*theta);
|
const float cos_theta = cosf(p*theta);
|
||||||
const double sin_theta = sin(p*theta);
|
const float sin_theta = sinf(p*theta);
|
||||||
|
|
||||||
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
double x0 = (double)src[0];
|
const float x0 = src[0];
|
||||||
double x1 = (double)src[1];
|
const float x1 = src[1];
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||||
|
@ -7054,16 +7054,16 @@ static void ggml_compute_forward_rope_f16(
|
||||||
const int p = (mode == 0 ? n_past + i2 : i2);
|
const int p = (mode == 0 ? n_past + i2 : i2);
|
||||||
for (int i1 = 0; i1 < ne1; i1++) {
|
for (int i1 = 0; i1 < ne1; i1++) {
|
||||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||||
const double theta = pow(10000.0, ((double)-i0)/n_dims);
|
const float theta = powf(10000.0, ((float)-i0)/n_dims);
|
||||||
|
|
||||||
const float cos_theta = cos(p*theta);
|
const float cos_theta = cosf(p*theta);
|
||||||
const float sin_theta = sin(p*theta);
|
const float sin_theta = sinf(p*theta);
|
||||||
|
|
||||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
float x0 = ggml_fp16_to_fp32(src[0]);
|
const float x0 = ggml_fp16_to_fp32(src[0]);
|
||||||
float x1 = ggml_fp16_to_fp32(src[1]);
|
const float x1 = ggml_fp16_to_fp32(src[1]);
|
||||||
|
|
||||||
dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
|
dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
|
||||||
dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
|
dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
|
||||||
|
|
40
llama.cpp
40
llama.cpp
|
@ -779,8 +779,8 @@ static bool llama_model_load(
|
||||||
|
|
||||||
// progress
|
// progress
|
||||||
if (progress_callback) {
|
if (progress_callback) {
|
||||||
double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
|
float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
|
||||||
double current_progress = (double(i) + current_file_progress) / double(n_parts);
|
float current_progress = (float(i) + current_file_progress) / float(n_parts);
|
||||||
progress_callback(current_progress, progress_callback_user_data);
|
progress_callback(current_progress, progress_callback_user_data);
|
||||||
}
|
}
|
||||||
if (model.n_loaded % 8 == 0) {
|
if (model.n_loaded % 8 == 0) {
|
||||||
|
@ -1240,12 +1240,12 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
|
||||||
// sampling
|
// sampling
|
||||||
//
|
//
|
||||||
|
|
||||||
static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
|
static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
|
||||||
// find the top k tokens
|
// find the top k tokens
|
||||||
std::partial_sort(
|
std::partial_sort(
|
||||||
logits_id.begin(),
|
logits_id.begin(),
|
||||||
logits_id.begin() + top_k, logits_id.end(),
|
logits_id.begin() + top_k, logits_id.end(),
|
||||||
[](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
|
[](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
|
||||||
return a.first > b.first;
|
return a.first > b.first;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1256,9 +1256,9 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
const std::vector<llama_vocab::id> & last_n_tokens,
|
const std::vector<llama_vocab::id> & last_n_tokens,
|
||||||
int top_k,
|
int top_k,
|
||||||
double top_p,
|
float top_p,
|
||||||
double temp,
|
float temp,
|
||||||
double repeat_penalty) {
|
float repeat_penalty) {
|
||||||
auto & rng = lctx.rng;
|
auto & rng = lctx.rng;
|
||||||
|
|
||||||
const int n_logits = lctx.model.hparams.n_vocab;
|
const int n_logits = lctx.model.hparams.n_vocab;
|
||||||
|
@ -1266,41 +1266,41 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
||||||
const auto & logits = lctx.logits;
|
const auto & logits = lctx.logits;
|
||||||
const auto * plogits = logits.data() + logits.size() - n_logits;
|
const auto * plogits = logits.data() + logits.size() - n_logits;
|
||||||
|
|
||||||
std::vector<std::pair<double, llama_vocab::id>> logits_id;
|
std::vector<std::pair<float, llama_vocab::id>> logits_id;
|
||||||
logits_id.reserve(n_logits);
|
logits_id.reserve(n_logits);
|
||||||
|
|
||||||
{
|
{
|
||||||
const double scale = 1.0/temp;
|
const float scale = 1.0f/temp;
|
||||||
for (int i = 0; i < n_logits; ++i) {
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
|
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
|
||||||
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
|
||||||
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
|
||||||
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||||
if (plogits[i] < 0.0f) {
|
if (plogits[i] < 0.0f) {
|
||||||
logits_id.push_back(std::make_pair((double)plogits[i]*scale*repeat_penalty, i));
|
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
|
||||||
} else {
|
} else {
|
||||||
logits_id.push_back(std::make_pair((double)plogits[i]*scale/repeat_penalty, i));
|
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logits_id.push_back(std::make_pair((double)plogits[i]*scale, i));
|
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sample_top_k(logits_id, top_k);
|
sample_top_k(logits_id, top_k);
|
||||||
|
|
||||||
double maxl = -std::numeric_limits<double>::infinity();
|
float maxl = -std::numeric_limits<float>::infinity();
|
||||||
for (const auto & kv : logits_id) {
|
for (const auto & kv : logits_id) {
|
||||||
maxl = std::max(maxl, kv.first);
|
maxl = std::max(maxl, kv.first);
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute probs for the top k tokens
|
// compute probs for the top k tokens
|
||||||
std::vector<double> probs;
|
std::vector<float> probs;
|
||||||
probs.reserve(logits_id.size());
|
probs.reserve(logits_id.size());
|
||||||
|
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (const auto & kv : logits_id) {
|
for (const auto & kv : logits_id) {
|
||||||
double p = exp(kv.first - maxl);
|
const float p = expf(kv.first - maxl);
|
||||||
probs.push_back(p);
|
probs.push_back(p);
|
||||||
sum += p;
|
sum += p;
|
||||||
}
|
}
|
||||||
|
@ -1590,7 +1590,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
for (int i = 0; i < (int) hist_cur.size(); ++i) {
|
||||||
printf("%5.3f ", hist_cur[i] / (double)nelements);
|
printf("%5.3f ", hist_cur[i] / float(nelements));
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
} else {
|
} else {
|
||||||
|
@ -1613,7 +1613,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
|
|
||||||
printf("%s: hist: ", __func__);
|
printf("%s: hist: ", __func__);
|
||||||
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
for (int i = 0; i < (int) hist_all.size(); ++i) {
|
||||||
printf("%5.3f ", hist_all[i] / (double)sum_all);
|
printf("%5.3f ", hist_all[i] / float(sum_all));
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
@ -1795,9 +1795,9 @@ llama_token llama_sample_top_p_top_k(
|
||||||
const llama_token * last_n_tokens_data,
|
const llama_token * last_n_tokens_data,
|
||||||
int last_n_tokens_size,
|
int last_n_tokens_size,
|
||||||
int top_k,
|
int top_k,
|
||||||
double top_p,
|
float top_p,
|
||||||
double temp,
|
float temp,
|
||||||
double repeat_penalty) {
|
float repeat_penalty) {
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
llama_token result = 0;
|
llama_token result = 0;
|
||||||
|
|
8
llama.h
8
llama.h
|
@ -45,7 +45,7 @@ extern "C" {
|
||||||
|
|
||||||
} llama_token_data;
|
} llama_token_data;
|
||||||
|
|
||||||
typedef void (*llama_progress_callback)(double progress, void *ctx);
|
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||||
|
|
||||||
struct llama_context_params {
|
struct llama_context_params {
|
||||||
int n_ctx; // text context
|
int n_ctx; // text context
|
||||||
|
@ -134,9 +134,9 @@ extern "C" {
|
||||||
const llama_token * last_n_tokens_data,
|
const llama_token * last_n_tokens_data,
|
||||||
int last_n_tokens_size,
|
int last_n_tokens_size,
|
||||||
int top_k,
|
int top_k,
|
||||||
double top_p,
|
float top_p,
|
||||||
double temp,
|
float temp,
|
||||||
double repeat_penalty);
|
float repeat_penalty);
|
||||||
|
|
||||||
// Performance information
|
// Performance information
|
||||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue