From f3cf7df21fcd95ca618caa6c57e63483cdc3eb12 Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 15 May 2023 14:18:57 +0200 Subject: [PATCH] better weight initialization improves training convergence at start --- examples/baby-llama/baby-llama.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 5573c154b..e5639da37 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -79,34 +79,39 @@ struct ggml_tensor * randomize_tensor_normal( int ndims, const int64_t ne[], struct random_normal_distribution * rnd) { + float scale = 1.0; // xavier switch (ndims) { case 1: + scale /= sqrtf(ne[0]); for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i0] = frand_normal(rnd); + ((float *)tensor->data)[i0] = scale * frand_normal(rnd); } break; case 2: + scale /= sqrtf(ne[0]+ne[1]); for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd); } } break; case 3: + scale /= sqrtf(ne[0]+ne[1]); for (int i2 = 0; i2 < ne[2]; i2++) { for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd); } } } break; case 4: + scale /= sqrtf(ne[0]+ne[1]); for (int i3 = 0; i3 < ne[3]; i3++) { for (int i2 = 0; i2 < ne[2]; i2++) { for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd); } } }