diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 81f1aeedc..8e3a9de23 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1629,12 +1629,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1711,7 +1711,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1720,9 +1720,9 @@ kernel void kernel_mul_mv_q4_1_f32( device float * dst, constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1733,7 +1733,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1744,16 +1744,17 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 +template void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1829,17 +1830,17 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1909,7 +1910,7 @@ kernel void kernel_mul_mv( constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( + kernel_mul_mv_impl( src0, src1, dst, @@ -3931,11 +3932,12 @@ kernel void kernel_concat( } } +template void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4030,14 +4032,15 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4189,14 +4192,15 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4304,14 +4308,15 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4433,14 +4438,15 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4521,16 +4527,17 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit +template void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4629,14 +4636,15 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4745,14 +4753,15 @@ kernel void kernel_mul_mv_iq2_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4854,14 +4863,15 @@ kernel void kernel_mul_mv_iq3_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4963,14 +4973,15 @@ kernel void kernel_mul_mv_iq3_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5073,14 +5084,15 @@ kernel void kernel_mul_mv_iq2_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5159,11 +5171,12 @@ void kernel_mul_mv_iq1_s_f32_impl( } } +template void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5251,11 +5264,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } +template void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5338,11 +5352,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } +template void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5436,7 +5451,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -5449,7 +5464,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -5463,7 +5478,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5477,7 +5492,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } template @@ -6062,7 +6077,7 @@ void mmv_fn( impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id(