replaced call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large
Replaced the call to kernel_mul_mv_f16_f32_l4 with a call to kernel_mul_mv_f16_f32_l4_large for vectors larger than 128 elements.
This commit is contained in:
parent
cd2322c996
commit
b3d55bcc72
1 changed files with 58 additions and 0 deletions
|
@ -1598,6 +1598,64 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|||
}
|
||||
}
|
||||
|
||||
// Matrix-vector product: rows of half-precision values in src0 are dotted with
// rows of float values in src1, and the float results are written to dst.
//
// Work layout, as established by the indexing below:
//   - tgpig.x selects a tile of 32 consecutive src0 rows (r0 = tgpig.x*32 + j).
//   - tgpig.z is the flattened index `im` over dims 2 and 3 of src1/dst; it is
//     split into i12 = im % ne12 and i13 = im / ne12.
//   - Every src1 row r1 in [0, ne11) is processed for each src0 row.
//   - The lanes of a simdgroup split one dot product: lane `tiisg` handles
//     4-element vectors tiisg, tiisg+32, tiisg+64, ... and the per-lane
//     partials are combined with simd_sum.
//
// Preconditions (not checked here):
//   - Only ne00/4 four-element vectors are read per row, so any trailing
//     ne00 % 4 elements are ignored; presumably the host dispatches this
//     kernel only when ne00 is a multiple of 4 -- TODO confirm.
//   - The lane stride of 32 assumes a simdgroup width of 32.
//   - r2 and r3 divide i12/i13 when addressing src0; presumably these are the
//     src1-to-src0 repeat factors for dims 2 and 3 -- verify against caller.
//
// Changes relative to the previous revision:
//   - The accumulator is reset for every (r0, r1) pair. Previously it was
//     cleared once per r0, so results for r1 > 0 also contained the sums of
//     all earlier src1 rows.
//   - The accumulator lives in a register instead of threadgroup memory.
//     simd_sum operates on per-lane values, so no threadgroup scratch array
//     or threadgroup_barrier is needed, and simdgroups sharing a threadgroup
//     can no longer race on a scratch array indexed only by lane id.
//   - Rows with r0 >= ne01 are skipped so that the last tile neither reads
//     src0 nor writes dst out of range.
//   - The src0 byte offset is kept in 64 bits instead of being truncated to
//     a 32-bit uint.
kernel void kernel_mul_mv_f16_f32_l4_large(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {

    const int     nrows   = ne11;
    const int64_t base_r0 = tgpig.x*32;
    const int64_t im      = tgpig.z;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // Number of 4-element vectors per row; hoisted out of the inner loops.
    const int64_t nvec4 = ne00/4;

    for (int j = 0; j < 32; ++j) {
        const int64_t r0 = base_r0 + j;

        // r0 depends only on tgpig.x and j, so every thread of the threadgroup
        // leaves the loop on the same iteration.
        if (r0 >= ne01) {
            break;
        }

        // Byte offset of src0 row r0 within the (i12/r2, i13/r3) slice.
        const uint64_t offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

        device const half4 * x4 = (device const half4 *) (src0 + offset0);

        for (int r1 = 0; r1 < nrows; ++r1) {
            device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);

            // Per-lane partial dot product for this (r0, r1) pair only.
            float sumf = 0.0f;
            for (int i = tiisg; i < nvec4; i += 32) {
                for (int k = 0; k < 4; ++k) {
                    sumf += (float) x4[i][k] * y4[i][k];
                }
            }

            // Combine the partials of all lanes in the simdgroup.
            const float all_sum = simd_sum(sumf);

            if (tiisg == 0) {
                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
            }
        }
    }
}
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue