From 29a0f8b94014ebf5a9f9197cb2ed5f55aeac9804 Mon Sep 17 00:00:00 2001 From: xaedes <xaedes@gmail.com> Date: Mon, 1 May 2023 20:02:48 +0200 Subject: [PATCH] fix softmax in baby-llama example --- examples/baby-llama/baby-llama.cpp | 36 +++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 7f538479d..9feafae98 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -520,10 +520,16 @@ void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) { ggml_set_i32_1d(best_samples, i, k); } } + float psum = 0; for (int k = 0; k < logits->ne[0]; ++k) { float logit = ggml_get_f32_1d(logits, i * logits->ne[0] + k); - float p = expf(logit - max_logit); - ggml_set_i32_1d(probs, i * probs->ne[0] + k, p); + float p = (logit == -INFINITY) ? 0 : expf(logit - max_logit); + psum += p; + ggml_set_f32_1d(probs, i * probs->ne[0] + k, p); + } + for (int k = 0; k < logits->ne[0]; ++k) { + float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k); + ggml_set_f32_1d(probs, i * probs->ne[0] + k, p / psum); } } } @@ -532,7 +538,7 @@ void print_probs(struct ggml_tensor * probs) { assert(probs->n_dims == 2); for (int i=0; i<probs->ne[1]; ++i) { for (int k = 0; k < probs->ne[0]; ++k) { - float p = ggml_get_f32_1d(probs, i*probs->ne[1] + k); + float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k); printf(" %.1f", p); } printf("\n"); @@ -588,11 +594,11 @@ int main(int argc, char ** argv) { int n_tokens = 64; struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens); - struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens); - struct ggml_tensor * 
tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - struct ggml_tensor * targets = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens); + struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens); + struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens); + struct ggml_tensor * tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + struct ggml_tensor * targets = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens); for (int i=0; i