metal : fix mat-vec Q4_K kernel for QK_K == 64
This commit is contained in:
parent
a8b9bb4566
commit
5865b18eeb
1 changed files with 4 additions and 4 deletions
|
@ -3018,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
const int ix = tiisg/4; // 0...7
|
const int ix = tiisg/4; // 0...7
|
||||||
const int it = tiisg%4; // 0...3
|
const int it = tiisg%4; // 0...3
|
||||||
|
@ -3028,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||||
const int r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
const int first_row = r0 * N_DST;
|
||||||
const int ib_row = first_row * nb;
|
const int ib_row = first_row * nb;
|
||||||
|
|
||||||
const uint i12 = im%ne12;
|
const uint i12 = im%ne12;
|
||||||
|
@ -3094,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue