metal: fix bugs for GQA and perplexity test.

I mixed up ne02 and nb02 in the previous commit.
This commit is contained in:
lshzh-ww 2023-08-15 11:31:13 -04:00
parent bfa455de43
commit a527eccb43

View file

@ -343,16 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
// N_DST, so this is another explicit assumption of the implementation. // N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw> template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
uint3 tgpig, uint tiisg, uint sgitg) { uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0; const int nb = ne00/QK4_0;
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
const int im = tgpig.z; const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr; const int first_row = (r0 * nsg + sgitg) * nr;
const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0); const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const block_q_type * x = (device const block_q_type *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; // src1 vector cache float yl[16]; // src1 vector cache
float sumf[nr]={0.f}; float sumf[nr]={0.f};
@ -383,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
for (int row = 0; row < nr; ++row) { for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]); const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < ne01) { if (tiisg == 0 && first_row + row < ne01) {
dst[r1*ne0 + im*ne12 + first_row + row] = tot; dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
} }
} }
} }
@ -398,11 +398,12 @@ kernel void kernel_mul_mat_q4_0_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_q4_1_f32( kernel void kernel_mul_mat_q4_1_f32(
@ -415,11 +416,12 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_f16_f32( kernel void kernel_mul_mat_f16_f32(
@ -800,6 +802,7 @@ kernel void kernel_mul_mat_q2_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -812,9 +815,9 @@ kernel void kernel_mul_mat_q2_K_f32(
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb; const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[32]; float yl[32];
float sumf[N_DST]={0.f}, all_sum; float sumf[N_DST]={0.f}, all_sum;
@ -927,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]); all_sum = simd_sum(sumf[row]);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum; dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
} }
} }
} }
@ -943,6 +946,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -955,9 +959,9 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r2 = tgpig.z; const int64_t r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16]; float yl[16];
@ -1045,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
const float tot = simd_sum(sumf); const float tot = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = tot; dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
} }
} }
} }
@ -1060,6 +1064,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -1072,9 +1077,9 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r2 = tgpig.z; const int64_t r2 = tgpig.z;
const int row = 2 * r0 + sgitg; const int row = 2 * r0 + sgitg;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
const int ix = tiisg/4; const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int il = 4 * (tiisg%4);// 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1 const int im = il/8; // 0, 0, 1, 1
@ -1113,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
const float tot = simd_sum(sumf); const float tot = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + row] = tot; dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
} }
} }
@ -1130,6 +1135,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -1150,9 +1156,9 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb; const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16]; float yl[16];
float yh[16]; float yh[16];
float sumf[N_DST]={0.f}, all_sum; float sumf[N_DST]={0.f}, all_sum;
@ -1219,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]); all_sum = simd_sum(sumf[row]);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum; dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
} }
} }
} }
@ -1234,6 +1240,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -1248,9 +1255,9 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb; const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[8]; float yl[8];
float yh[8]; float yh[8];
float sumf[N_DST]={0.f}, all_sum; float sumf[N_DST]={0.f}, all_sum;
@ -1306,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
for (int row = 0; row < N_DST; ++row) { for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]); all_sum = simd_sum(sumf[row]);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0+ r2*ne12 + first_row + row] = all_sum; dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
} }
} }
} }
@ -1322,6 +1329,7 @@ kernel void kernel_mul_mat_q5_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -1334,9 +1342,9 @@ kernel void kernel_mul_mat_q5_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float sumf[2]={0.f}; float sumf[2]={0.f};
@ -1470,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
for (int row = 0; row < 2; ++row) { for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]); const float tot = simd_sum(sumf[row]);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = tot; dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
} }
} }
@ -1486,6 +1494,7 @@ kernel void kernel_mul_mat_q6_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]], constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
@ -1503,9 +1512,9 @@ kernel void kernel_mul_mat_q6_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int row = 2 * r0 + sgitg; const int row = 2 * r0 + sgitg;
const uint offset0 = r2/gqa*(ne02/QK_K); const uint offset0 = r2/gqa*(nb*ne0);
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float sumf = 0; float sumf = 0;
@ -1571,7 +1580,7 @@ kernel void kernel_mul_mat_q6_K_f32(
const float tot = simd_sum(sumf); const float tot = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + row] = tot; dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
} }
} }
@ -1835,7 +1844,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12; + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory //load data and store to threadgroup memory
@ -1880,7 +1889,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne12; + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
} }
@ -1893,7 +1902,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne12; device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg==0) { if (sgitg==0) {
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {