replaced call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large

Replaced the call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large for vectors larger than 128 elements.
This commit is contained in:
Alexander Komarov 2024-05-24 11:52:13 -07:00 committed by GitHub
parent cd2322c996
commit b3d55bcc72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1598,6 +1598,64 @@ kernel void kernel_mul_mv_f16_f32_l4(
}
}
// Matrix-vector product of f16 rows of src0 with f32 rows of src1, written to dst.
//
// Work layout, as established by the indexing below:
//   - tgpig.x selects a tile of 32 consecutive src0 rows (r0 = tgpig.x*32 + j, j in [0, 32)).
//   - tgpig.z is the flattened batch index `im`, split into i12 = im % ne12 and i13 = im / ne12;
//     r2 / r3 divide those indices to pick the src0 batch.
//   - Lanes of a simdgroup stride over the row in half4/float4 chunks with a step of 32
//     and are combined with simd_sum; lane 0 stores the result.
//
// Preconditions (not checked here):
//   - Only ne00/4 full 4-element chunks are processed; any remainder of ne00 % 4 is ignored.
//   - The lane stride of 32 assumes a 32-wide simdgroup - TODO confirm against the host dispatch.
//   - src0/src1 row addresses must be suitably aligned for half4/float4 loads.
//
// No simdgroup index is used, so if a threadgroup contains several simdgroups each one
// redundantly computes and stores the same values.
kernel void kernel_mul_mv_f16_f32_l4_large(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {
    const int     nrows   = ne11;
    const int64_t base_r0 = (int64_t) tgpig.x*32;
    const int64_t im      = tgpig.z;

    // Number of 4-wide chunks per row; loop-invariant, so computed once.
    const int64_t nchunks = ne00/4;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    for (int j = 0; j < 32; ++j) {
        const int64_t r0 = base_r0 + j;

        // The last tile may be partial when ne01 is not a multiple of 32: stop before
        // reading past src0 or writing into a neighbouring dst row. r0 is uniform across
        // the threadgroup, so every thread leaves the loop together.
        if (r0 >= ne01) {
            break;
        }

        // Byte offset kept in 64 bits so large tensors do not wrap around.
        const uint64_t offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

        device const half4 * x4 = (device const half4 *) (src0 + offset0);

        for (int r1 = 0; r1 < nrows; ++r1) {
            device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);

            // Per-lane accumulator lives in a register and is reset for every (r0, r1)
            // pair, so results of one src1 row never leak into the next.
            float sumf = 0.0f;
            for (int64_t i = tiisg; i < nchunks; i += 32) {
                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
            }

            // simd_sum reduces across the lanes of the simdgroup; no threadgroup memory
            // or threadgroup barrier is needed for that.
            const float all_sum = simd_sum(sumf);
            if (tiisg == 0) {
                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
            }
        }
    }
}
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));