better weight initialization improves training convergence at start

This commit is contained in:
xaedes 2023-05-15 14:19:38 +02:00
parent f3cf7df21f
commit 19fb91899b
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -51,38 +51,43 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
} }
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) { struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
float scale = 1.0f; // xavier
switch (tensor->n_dims) { switch (tensor->n_dims) {
case 1: case 1:
scale /= sqrtf(tensor->ne[0]);
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
*dst = frand_normal(rnd); *dst = scale * frand_normal(rnd);
} }
break; break;
case 2: case 2:
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
*dst = frand_normal(rnd); *dst = scale * frand_normal(rnd);
} }
} }
break; break;
case 3: case 3:
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
*dst = frand_normal(rnd); *dst = scale * frand_normal(rnd);
} }
} }
} }
break; break;
case 4: case 4:
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
for (int i3 = 0; i3 < tensor->ne[3]; i3++) { for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
for (int i2 = 0; i2 < tensor->ne[2]; i2++) { for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) { for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
*dst = frand_normal(rnd); *dst = scale * frand_normal(rnd);
} }
} }
} }