From a527eccb43c3ac72ff5458961eec2435a4e0cdfe Mon Sep 17 00:00:00 2001 From: lshzh-ww Date: Tue, 15 Aug 2023 11:31:13 -0400 Subject: [PATCH] metal: fix bugs for GQA and perplexity test. I mixed up ne02 and nb02 in previous commit. --- ggml-metal.metal | 69 +++++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 67b1f1303..3f3125236 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -343,16 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // N_DST, so this is another explicit assumption of the implementation. template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0); + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne12; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[16]; // src1 vector cache float sumf[nr]={0.f}; @@ -383,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne12 + first_row + row] = tot; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } } @@ -398,11 +398,12 @@ kernel void kernel_mul_mat_q4_0_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -415,11 +416,12 @@ kernel void kernel_mul_mat_q4_1_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( @@ -800,6 +802,7 @@ kernel void kernel_mul_mat_q2_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -812,9 +815,9 @@ kernel void kernel_mul_mat_q2_K_f32( const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -927,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -943,6 +946,7 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -955,9 +959,9 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[16]; @@ -1045,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32( const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + r2*ne12 + first_row + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } } } @@ -1060,6 +1064,7 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -1072,9 +1077,9 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r2 = tgpig.z; const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; const int ix = tiisg/4; const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 @@ -1113,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + r2*ne12 + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; } } @@ -1130,6 +1135,7 @@ kernel void kernel_mul_mat_q4_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -1150,9 +1156,9 @@ kernel void kernel_mul_mat_q4_K_f32( const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; @@ -1219,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1234,6 +1240,7 @@ kernel void kernel_mul_mat_q4_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -1248,9 +1255,9 @@ kernel void kernel_mul_mat_q4_K_f32( const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float yl[8]; float yh[8]; float sumf[N_DST]={0.f}, all_sum; @@ -1306,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0+ r2*ne12 + first_row + row] = all_sum; + dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1322,6 +1329,7 @@ kernel void kernel_mul_mat_q5_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -1334,9 +1342,9 @@ kernel void kernel_mul_mat_q5_K_f32( const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float sumf[2]={0.f}; @@ -1470,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32( for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + r2*ne12 + first_row + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; } } @@ -1486,6 +1494,7 @@ kernel void kernel_mul_mat_q6_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -1503,9 +1512,9 @@ kernel void kernel_mul_mat_q6_K_f32( const int r2 = tgpig.z; const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(ne02/QK_K); + const uint offset0 = r2/gqa*(nb*ne0); device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; float sumf = 0; @@ -1571,7 +1580,7 @@ kernel void kernel_mul_mat_q6_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + r2*ne12 + row] = tot; + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; } } @@ -1835,7 +1844,7 @@ kernel void kernel_mul_mm(device const uchar * src0, uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ - + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12; + + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { //load data and store to threadgroup memory @@ -1880,7 +1889,7 @@ kernel void kernel_mul_mm(device const uchar * src0, if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne12; + + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); } @@ -1893,7 +1902,7 @@ kernel void kernel_mul_mm(device const uchar * src0, } threadgroup_barrier(mem_flags::mem_threadgroup); - device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne12; + device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; if (sgitg==0) { for (int i = 0; i < n_rows; i++) { for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {