From 19fb91899bad98423a5a14cea8cfa22a3334432d Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 15 May 2023 14:19:38 +0200 Subject: [PATCH] better weight initialization improves training convergence at start --- examples/baby-llama/baby-llama-text.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp index 9f2ff9034..b5177ed5b 100644 --- a/examples/baby-llama/baby-llama-text.cpp +++ b/examples/baby-llama/baby-llama-text.cpp @@ -51,38 +51,43 @@ float frand_uniform(struct random_uniform_distribution * rnd) { } struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) { + float scale = 1.0f; // xavier switch (tensor->n_dims) { case 1: + scale /= sqrtf(tensor->ne[0]); for (int i0 = 0; i0 < tensor->ne[0]; i0++) { float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); - *dst = frand_normal(rnd); + *dst = scale * frand_normal(rnd); } break; case 2: + scale /= sqrtf(tensor->ne[0]*tensor->ne[1]); for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) { float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); - *dst = frand_normal(rnd); + *dst = scale * frand_normal(rnd); } } break; case 3: + scale /= sqrtf(tensor->ne[0]*tensor->ne[1]); for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) { float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); - *dst = frand_normal(rnd); + *dst = scale * frand_normal(rnd); } } } break; case 4: + scale /= sqrtf(tensor->ne[0]*tensor->ne[1]); for (int i3 = 0; i3 < tensor->ne[3]; i3++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) { float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); - *dst = frand_normal(rnd); + *dst = scale * frand_normal(rnd); } } }