Fix kernel_norm broken by ca82cf7

This commit is contained in:
Iwan Kawrakow 2023-09-07 15:17:43 +02:00
parent fec2fb19e4
commit f9c8ccc12f

View file

@ -220,14 +220,10 @@ kernel void kernel_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
//// broadcast const float mean = sum[0] / ne00;
//if (tpitg == 0) {
// sum[0] /= ne00;
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0];
// recenter and VARIANCE // recenter and VARIANCE
threadgroup_barrier(mem_flags::mem_threadgroup);
device float * y = dst + tgpig*ne00; device float * y = dst + tgpig*ne00;
sum[tpitg] = 0.0f; sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@ -235,12 +231,6 @@ kernel void kernel_norm(
sum[tpitg] += y[i00] * y[i00]; sum[tpitg] += y[i00] * y[i00];
} }
//// VARIANCE
//// parallel sum
//sum[tpitg] = 0.0f;
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
// sum[tpitg] += y[i00] * y[i00];
//}
// reduce // reduce
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) { for (uint i = ntg/2; i > 0; i /= 2) {
@ -249,12 +239,7 @@ kernel void kernel_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
//// broadcast const float variance = sum[0] / ne00;
//if (tpitg == 0) {
// sum[0] /= ne00;
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + eps); const float scale = 1.0f/sqrt(variance + eps);
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {