diff --git a/ggml-metal.metal b/ggml-metal.metal index 689411903..74a5e0b03 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2560,12 +2560,16 @@ typedef struct { uint8_t qs[QK4_NL/2]; } block_iq4_nl; +#if QK_K == 64 +#define block_iq4_xs block_iq4_nl +#else typedef struct { half d; uint16_t scales_h; uint8_t scales_l[QK_K/64]; uint8_t qs[QK_K/2]; } block_iq4_xs; +#endif //====================================== dot products ========================= @@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); } -#if QK_K == 256 const int ix = tiisg; device const float * y4 = y + 32 * ix; @@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } -#else - (void) x; - (void) y; - (void) yl; - (void) nb32; -#endif for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); } -#if QK_K == 256 const int ix = tiisg; device const float * y4 = y + 32 * ix; @@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } -#else - (void) x; - (void) y; - (void) yl; - (void) nb32; -#endif for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); } -#if QK_K == 256 const int ix = tiisg; device const float * y4 = y + 32 * ix; @@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } -#else - (void) x; - (void) y; - (void) yl; - (void) nb32; -#endif for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl( const int nb32 = nb * (QK_K / 32); -#if QK_K == 256 const int ix = tiisg/2; const int il = tiisg%2; @@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 16 * 32; } -#else - (void) x; - (void) y; - (void) yl; - (void) nb32; -#endif for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } +#if QK_K != 64 void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, @@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( } } } +#endif [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( @@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { +#if QK_K == 64 + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +#else kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +#endif } //============================= templates and their specializations ============================= @@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 template void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { +#if QK_K == 64 + dequantize_iq4_nl(xb, il, reg); +#else // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 const int ib32 = il/2; il = il%2; @@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; } +#endif } template @@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; +#if QK_K == 64 +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +#else template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +#endif // // matrix-matrix multiplication @@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +#if QK_K == 64 +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +#else template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +#endif // // indirect matrix-matrix multiplication @@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#if QK_K == 64 +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#else template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#endif // // matrix-vector multiplication @@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32( const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; +#if QK_K == 64 + kernel_mul_mv_iq4_nl_f32_impl( +#else kernel_mul_mv_iq4_xs_f32_impl( +#endif src0[id], (device const float *) (src1 + bid*nb11), dst + bid*ne0, diff --git a/ggml-quants.c b/ggml-quants.c index da186cb75..c21ba6e38 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10262,7 +10262,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const int nb = n / QK_K; -#if defined __ARM_NEON +#if defined __ARM_NEON && QK_K != 64 const uint8x16_t m8 = vdupq_n_u8(0x08); const uint8x16_t m7 = vdupq_n_u8(0x07);