metal: update rms_norm kernel
This commit doubles the speed of rms_norm operations by using 512 threads per threadgroup, combined with SIMD primitives to minimize the need for threadgroup barriers.
parent bbce392890
commit 4088df14ca
2 changed files with 27 additions and 16 deletions
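For context, the reduction scheme the new kernel adopts looks roughly like the sketch below (illustrative only, not code from this repository): each thread accumulates a strided slice in registers, simd_sum reduces within each SIMD group with no barrier or shared memory, one partial sum per SIMD group lands in threadgroup memory, and a single barrier precedes the final fold. The buffer names and the closing simd_sum fold are assumptions for the sketch; the commit's kernel folds its nth/32 partials with a short tree loop instead.

#include <metal_stdlib>
using namespace metal;

// Illustrative sketch of the two-level simdgroup reduction (not part of this commit).
kernel void simd_reduce_sum_example(
        device const float * src     [[buffer(0)]],
        device       float * dst     [[buffer(1)]],
        constant     uint  & n       [[buffer(2)]],
        threadgroup  float * partial [[threadgroup(0)]],   // ntg/32 floats, cf. nth/32 below
        uint tpitg [[thread_position_in_threadgroup]],
        uint sgitg [[simdgroup_index_in_threadgroup]],
        uint tiisg [[thread_index_in_simdgroup]],
        uint ntg   [[threads_per_threadgroup]]) {
    // each thread accumulates a strided slice of the input in registers
    float acc = 0.0f;
    for (uint i = tpitg; i < n; i += ntg) {
        acc += src[i];
    }

    // reduce within each SIMD group; no threadgroup memory or barrier needed
    acc = simd_sum(acc);

    // one partial sum per SIMD group goes to threadgroup memory
    if (tiisg == 0) {
        partial[sgitg] = acc;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // the first SIMD group folds the ntg/32 partials (at most 32 of them)
    if (sgitg == 0) {
        float v = tiisg < ntg/32 ? partial[tiisg] : 0.0f;
        v = simd_sum(v);
        if (tiisg == 0) {
            dst[0] = v;
        }
    }
}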
@@ -792,7 +792,7 @@ void ggml_metal_graph_compute(
 
                 const float eps = 1e-6f;
 
-                const int nth = 256;
+                const int nth = 512;
 
                 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -800,7 +800,7 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                 [encoder setBytes:&eps  length:sizeof( float) atIndex:4];
-                [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                 const int64_t nrows = ggml_nrows(src0);
 
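The encoder-side change mirrors the kernel: with nth = 512 threads and 32-wide SIMD groups, only nth/32 = 16 partial sums ever reach threadgroup memory (one per SIMD group), so the threadgroup allocation shrinks from one float per thread to one float per SIMD group.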
@@ -339,26 +339,33 @@ kernel void kernel_rms_norm(
         threadgroup float * sum [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint ntg[[threads_per_threadgroup]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+    float4 sumf=0;
+    float all_sum=0;
 
     // parallel sum
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        sum[tpitg] += x[i00] * x[i00];
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (tiisg == 0) {
+        sum[sgitg] = all_sum;
     }
 
-    // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast, simd group number is ntg / 32
+    for (int i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
     }
-
-    // broadcast
     if (tpitg == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
         sum[0] /= ne00;
     }
 
@@ -367,10 +374,14 @@ kernel void kernel_rms_norm(
     const float mean = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    device float * y_scalar = (device float *) y;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
+    if (tpitg == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+    }
 }
 
 // putting them in the kernel cause a significant performance penalty
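Two aspects of the kernel change are worth noting. The old reduction placed a threadgroup barrier inside every iteration of its tree over 256 per-thread sums; the new one reduces each SIMD group with simd_sum first, so a single barrier before the short loop over the ntg/32 partials is enough. In addition, both the squared-sum loop and the output loop now operate on float4 vectors, with thread 0 handling the scalar remainder via x_scalar and y_scalar when ne00 is not a multiple of 4.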