Fix kernel_norm broken by ca82cf7
This commit is contained in:
parent
fec2fb19e4
commit
f9c8ccc12f
1 changed files with 3 additions and 18 deletions
|
@ -220,14 +220,10 @@ kernel void kernel_norm(
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
const float mean = sum[0] / ne00;
|
||||||
//if (tpitg == 0) {
|
|
||||||
// sum[0] /= ne00;
|
|
||||||
//}
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
const float mean = sum[0];
|
|
||||||
|
|
||||||
// recenter and VARIANCE
|
// recenter and VARIANCE
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
device float * y = dst + tgpig*ne00;
|
device float * y = dst + tgpig*ne00;
|
||||||
sum[tpitg] = 0.0f;
|
sum[tpitg] = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
|
@ -235,12 +231,6 @@ kernel void kernel_norm(
|
||||||
sum[tpitg] += y[i00] * y[i00];
|
sum[tpitg] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
//// VARIANCE
|
|
||||||
//// parallel sum
|
|
||||||
//sum[tpitg] = 0.0f;
|
|
||||||
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
// sum[tpitg] += y[i00] * y[i00];
|
|
||||||
//}
|
|
||||||
// reduce
|
// reduce
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||||
|
@ -249,12 +239,7 @@ kernel void kernel_norm(
|
||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
const float variance = sum[0] / ne00;
|
||||||
//if (tpitg == 0) {
|
|
||||||
// sum[0] /= ne00;
|
|
||||||
//}
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
const float variance = sum[0];
|
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
const float scale = 1.0f/sqrt(variance + eps);
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue