better weight initialization improves training convergence at start
This commit is contained in:
parent
f3cf7df21f
commit
19fb91899b
1 changed files with 9 additions and 4 deletions
|
@ -51,38 +51,43 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
|
|||
}
|
||||
|
||||
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
|
||||
float scale = 1.0f; // xavier
|
||||
switch (tensor->n_dims) {
|
||||
case 1:
|
||||
scale /= sqrtf(tensor->ne[0]);
|
||||
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
|
||||
*dst = frand_normal(rnd);
|
||||
*dst = scale * frand_normal(rnd);
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
|
||||
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
||||
*dst = frand_normal(rnd);
|
||||
*dst = scale * frand_normal(rnd);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
|
||||
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
|
||||
*dst = frand_normal(rnd);
|
||||
*dst = scale * frand_normal(rnd);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
|
||||
for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
|
||||
*dst = frand_normal(rnd);
|
||||
*dst = scale * frand_normal(rnd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue