fix random weight initialization scale

This commit is contained in:
xaedes 2023-05-21 12:18:47 +02:00
parent 96514971dd
commit 57c2f4f909
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -68,7 +68,7 @@ struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct
}
break;
case 2:
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
@ -77,7 +77,7 @@ struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct
}
break;
case 3:
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
@ -88,7 +88,7 @@ struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct
}
break;
case 4:
scale /= sqrtf(tensor->ne[0]*tensor->ne[1]);
scale /= sqrtf(tensor->ne[0]+tensor->ne[1]);
for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {