diff --git a/ggml-metal.metal b/ggml-metal.metal index 625119460..f094a1d40 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -543,75 +543,6 @@ kernel void kernel_mul_mat_q4_1_f32( } } -//kernel void kernel_mul_mat_q4_1_f32( -// device const void * src0, -// device const float * src1, -// device float * dst, -// constant int64_t & ne00, -// constant int64_t & ne10, -// constant int64_t & ne0, -// threadgroup float * sum [[threadgroup(0)]], -// uint2 tgpig[[threadgroup_position_in_grid]], -// uint2 tpitg[[thread_position_in_threadgroup]], -// uint2 tptg[[threads_per_threadgroup]]) { -// const int nb = ne00/QK4_1; -// -// const int64_t r0 = tgpig.x; -// const int64_t r1 = tgpig.y; -// -// device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; -// device const float * y = (device const float *) src1 + r1*ne10; -// -// const uint nth = tptg.x*tptg.y; -// const uint ith = tptg.y*tpitg.x + tpitg.y; -// -// const int ix = tpitg.y/4; // 0 or 1 -// const int iy = tpitg.y - 4*ix; // 0...3 -// -// const int first = 4 * iy; -// -// float sumf = 0; -// -// for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { -// -// const float d = (float)x[i].d; -// const float m = (float)x[i].m; -// -// device const uint8_t * xl = x[i].qs + first; -// device const float * yl = y + i * QK4_1 + first; -// -// float2 acc = {0.0f, 0.0f}; -// -// for (int j = 0; j < 4; ++j) { -// -// acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); -// acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); -// -// } -// -// sumf += acc[0] + acc[1]; -// } -// -// sum[ith] = sumf; -// -// // -// // Accumulate the sum from all threads in the threadgroup -// // -// threadgroup_barrier(mem_flags::mem_threadgroup); -// if (ith%4 == 0) { -// sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; -// } -// threadgroup_barrier(mem_flags::mem_threadgroup); -// if (ith%16 == 0) { -// sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; -// } -// threadgroup_barrier(mem_flags::mem_threadgroup); -// if (ith == 0) { -// for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; -// dst[r1*ne0 + r0] = sum[0]; -// } -//} - kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1,