metal : assert various kernel requirements
This commit is contained in:
parent
42833bc7a8
commit
0f8df395ce
2 changed files with 25 additions and 13 deletions
|
@ -345,10 +345,11 @@ kernel void kernel_rms_norm(
|
|||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||
device const float * x_scalar = (device const float *) x;
|
||||
float4 sumf=0;
|
||||
float all_sum=0;
|
||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||
device const float * x_scalar = (device const float *) x;
|
||||
|
||||
float4 sumf = 0;
|
||||
float all_sum = 0;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
|
@ -361,6 +362,7 @@ kernel void kernel_rms_norm(
|
|||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
|
@ -368,7 +370,9 @@ kernel void kernel_rms_norm(
|
|||
}
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
||||
sum[0] += x_scalar[i];
|
||||
}
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
|
||||
|
@ -383,7 +387,9 @@ kernel void kernel_rms_norm(
|
|||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
||||
y_scalar[i00] = x_scalar[i00] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue