vulkan: account for lookup tables when checking shared memory size (#11502)
This commit is contained in:
		
							parent
							
								
									55ac8c7791
								
							
						
					
					
						commit
						98f6b0fd1e
					
				
					 1 changed files with 303 additions and 307 deletions
				
			
		|  | @ -184,12 +184,12 @@ struct vk_device_struct { | |||
| 
 | ||||
|     size_t idx; | ||||
| 
 | ||||
|     bool mul_mat_l; | ||||
|     bool mul_mat_m; | ||||
|     bool mul_mat_s; | ||||
|     bool mul_mat_id_l; | ||||
|     bool mul_mat_id_m; | ||||
|     bool mul_mat_id_s; | ||||
|     bool mul_mat_l[GGML_TYPE_COUNT]; | ||||
|     bool mul_mat_m[GGML_TYPE_COUNT]; | ||||
|     bool mul_mat_s[GGML_TYPE_COUNT]; | ||||
|     bool mul_mat_id_l[GGML_TYPE_COUNT]; | ||||
|     bool mul_mat_id_m[GGML_TYPE_COUNT]; | ||||
|     bool mul_mat_id_s[GGML_TYPE_COUNT]; | ||||
| 
 | ||||
|     // set to true to indicate that some shaders need to be compiled after the dryrun
 | ||||
|     bool need_compiles {}; | ||||
|  | @ -1378,7 +1378,33 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ | |||
|     return {64, 64}; | ||||
| }; | ||||
| 
 | ||||
| static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) { | ||||
| static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { | ||||
| 
 | ||||
|     uint32_t lut_size = 0; | ||||
|     switch (src0_type) { | ||||
|     case GGML_TYPE_IQ2_XXS: | ||||
|         lut_size = 8*256; | ||||
|         break; | ||||
|     case GGML_TYPE_IQ2_XS: | ||||
|         lut_size = 8*512; | ||||
|         break; | ||||
|     case GGML_TYPE_IQ2_S: | ||||
|         lut_size = 8*1024; | ||||
|         break; | ||||
|     case GGML_TYPE_IQ3_XXS: | ||||
|         lut_size = 4*256; | ||||
|         break; | ||||
|     case GGML_TYPE_IQ3_S: | ||||
|         lut_size = 4*512; | ||||
|         break; | ||||
|     case GGML_TYPE_IQ4_NL: | ||||
|     case GGML_TYPE_IQ4_XS: | ||||
|         lut_size = 4*16; | ||||
|         break; | ||||
|     default: | ||||
|         break; | ||||
|     } | ||||
| 
 | ||||
|     // Needs to be kept up to date on shader changes
 | ||||
|     const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; | ||||
|     const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); | ||||
|  | @ -1388,7 +1414,13 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec | |||
|     const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; | ||||
|     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; | ||||
| 
 | ||||
|     return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; | ||||
|     const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; | ||||
|     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; | ||||
| 
 | ||||
|     VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " | ||||
|                  "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); | ||||
| 
 | ||||
|     return supported; | ||||
| } | ||||
| 
 | ||||
| static void ggml_vk_load_shaders(vk_device& device) { | ||||
|  | @ -1472,62 +1504,32 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||
|         m_align =  64; | ||||
|         s_align =  32; | ||||
| 
 | ||||
|         // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
 | ||||
|         // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
 | ||||
|         // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
 | ||||
|         // But the numbers happen to work out for 32KB shared memory size that when using the medium
 | ||||
|         // size there's enough room for everything, and we assert for this.
 | ||||
|         uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); | ||||
|         if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { | ||||
|             l_warptile = m_warptile; | ||||
|             l_wg_denoms = m_wg_denoms; | ||||
|             shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); | ||||
|             GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|         if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { | ||||
|             // assert mul_mat_mat_id shaders will fit.
 | ||||
|             GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
| 
 | ||||
|         shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); | ||||
|         if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { | ||||
|             if (device->properties.limits.maxComputeSharedMemorySize == 32768) { | ||||
|                 l_warptile_mmq = m_warptile_mmq; | ||||
|                 l_mmq_wg_denoms = m_mmq_wg_denoms; | ||||
|             } else { | ||||
|                 l_warptile_mmq = s_warptile_mmq; | ||||
|                 l_mmq_wg_denoms = s_mmq_wg_denoms; | ||||
|         for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { | ||||
|             ggml_type t = (ggml_type)i; | ||||
|             // Disable medium and large matrix multiplication if not enough shared memory is available
 | ||||
|             // Check mmq warptiles as the largest configuration
 | ||||
|             // Throw an error if not enough for any matrix multiplication is available
 | ||||
|             if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) { | ||||
|                 std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; | ||||
|                 throw std::runtime_error("Shared memory size too small for matrix multiplication."); | ||||
|             } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) { | ||||
|                 device->mul_mat_m[i] = false; | ||||
|                 device->mul_mat_l[i] = false; | ||||
|             } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) { | ||||
|                 device->mul_mat_l[i] = false; | ||||
|             } | ||||
|             shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); | ||||
|             GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|         if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { | ||||
|             // assert mul_mat_mat_id shaders will fit.
 | ||||
|             GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); | ||||
|         } | ||||
|         // Disable medium and large matrix multiplication if not enough shared memory is available
 | ||||
|         // Check mmq warptiles as the largest configuration
 | ||||
|         // Throw an error if not enough for any matrix multiplication is available
 | ||||
|         if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { | ||||
|             std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; | ||||
|             throw std::runtime_error("Shared memory size too small for matrix multiplication."); | ||||
|         } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { | ||||
|             device->mul_mat_m = false; | ||||
|             device->mul_mat_l = false; | ||||
|         } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { | ||||
|             device->mul_mat_l = false; | ||||
|         } | ||||
| 
 | ||||
|         // Disable mul_mat_id if not enough shared memory is available
 | ||||
|         if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { | ||||
|             device->mul_mat_id_s = false; | ||||
|             device->mul_mat_id_m = false; | ||||
|             device->mul_mat_id_l = false; | ||||
|         } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { | ||||
|             device->mul_mat_id_m = false; | ||||
|             device->mul_mat_id_l = false; | ||||
|         } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { | ||||
|             device->mul_mat_id_l = false; | ||||
|             // Disable mul_mat_id if not enough shared memory is available
 | ||||
|             if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) { | ||||
|                 device->mul_mat_id_s[i] = false; | ||||
|                 device->mul_mat_id_m[i] = false; | ||||
|                 device->mul_mat_id_l[i] = false; | ||||
|             } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) { | ||||
|                 device->mul_mat_id_m[i] = false; | ||||
|                 device->mul_mat_id_l[i] = false; | ||||
|             } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) { | ||||
|                 device->mul_mat_id_l[i] = false; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -1684,119 +1686,116 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||
| #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) | ||||
|     if (device->coopmat_support) { | ||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 | ||||
| #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->mul_mat ## ID ## _l) \ | ||||
| #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->mul_mat ## ID ## _l[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true);   \ | ||||
|         if (device->mul_mat ## ID ## _m) \ | ||||
|         if (device->mul_mat ## ID ## _m[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true);   \ | ||||
|         if (device->mul_mat ## ID ## _s) \ | ||||
|         if (device->mul_mat ## ID ## _s[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true);   \ | ||||
|         if (device->mul_mat ## ID ## _l) \ | ||||
|         if (device->mul_mat ## ID ## _l[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true);   \ | ||||
|         if (device->mul_mat ## ID ## _m) \ | ||||
|         if (device->mul_mat ## ID ## _m[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true);   \ | ||||
|         if (device->mul_mat ## ID ## _s) \ | ||||
|         if (device->mul_mat ## ID ## _s[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true);   \ | ||||
| 
 | ||||
|         // Create 2 variants, {f16,f32} accumulator
 | ||||
| #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
| #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->coopmat_acc_f16_support) { \ | ||||
|             CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|             CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         } \ | ||||
|         if (device->coopmat_acc_f32_support) { \ | ||||
|             CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|             CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         } \ | ||||
| 
 | ||||
|         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         if (device->coopmat_acc_f16_support) { | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc,  matmul_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc,  matmul_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         } else { | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         } | ||||
| 
 | ||||
|         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
 | ||||
|         if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { | ||||
|             CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|             CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|             CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
| 
 | ||||
|             if (device->coopmat_acc_f16_support) { | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         if (device->coopmat_acc_f16_support) { | ||||
|             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
| 
 | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc,  matmul_id_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             } else { | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc,  matmul_id_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         } else { | ||||
|             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
| 
 | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             } | ||||
|             CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         } | ||||
| #undef CREATE_MM2 | ||||
| #undef CREATE_MM | ||||
|  | @ -1804,141 +1803,135 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||
| #endif  // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 | ||||
|     if (device->fp16) { | ||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 | ||||
| #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->mul_mat ## ID ## _l) \ | ||||
| #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->mul_mat ## ID ## _l[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \ | ||||
|         if (device->mul_mat ## ID ## _m) \ | ||||
|         if (device->mul_mat ## ID ## _m[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \ | ||||
|         if (device->mul_mat ## ID ## _s) \ | ||||
|         if (device->mul_mat ## ID ## _s[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \ | ||||
|         if (device->mul_mat ## ID ## _l) \ | ||||
|         if (device->mul_mat ## ID ## _l[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \ | ||||
|         if (device->mul_mat ## ID ## _m) \ | ||||
|         if (device->mul_mat ## ID ## _m[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \ | ||||
|         if (device->mul_mat ## ID ## _s) \ | ||||
|         if (device->mul_mat ## ID ## _s[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \ | ||||
| 
 | ||||
|         // Create 2 variants, {f16,f32} accumulator
 | ||||
| #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
| #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
| 
 | ||||
|         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc,  matmul_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc,  matmul_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
 | ||||
|         if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { | ||||
|             CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|             CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|             CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
| 
 | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
| 
 | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc,  matmul_id_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         } | ||||
|         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc,  matmul_id_iq2_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
| #undef CREATE_MM2 | ||||
| #undef CREATE_MM | ||||
|     } else { | ||||
|         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 | ||||
| #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->mul_mat ## ID ## _l) \ | ||||
| #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ | ||||
|         if (device->mul_mat ## ID ## _l[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \ | ||||
|         if (device->mul_mat ## ID ## _m) \ | ||||
|         if (device->mul_mat ## ID ## _m[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \ | ||||
|         if (device->mul_mat ## ID ## _s) \ | ||||
|         if (device->mul_mat ## ID ## _s[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \ | ||||
|         if (device->mul_mat ## ID ## _l) \ | ||||
|         if (device->mul_mat ## ID ## _l[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \ | ||||
|         if (device->mul_mat ## ID ## _m) \ | ||||
|         if (device->mul_mat ## ID ## _m[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \ | ||||
|         if (device->mul_mat ## ID ## _s) \ | ||||
|         if (device->mul_mat ## ID ## _s[TYPE]) \ | ||||
|             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \ | ||||
| 
 | ||||
|         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); | ||||
| 
 | ||||
|         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
 | ||||
|         if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { | ||||
|             CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); | ||||
| 
 | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
| 
 | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         } | ||||
|         CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
|         CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); | ||||
| #undef CREATE_MM | ||||
|     } | ||||
| 
 | ||||
|  | @ -2623,34 +2616,36 @@ static vk_device ggml_vk_get_device(size_t idx) { | |||
| 
 | ||||
|         // Shaders
 | ||||
|         // Disable matmul tile sizes early if performance low or not supported
 | ||||
|         switch (device->vendor_id) { | ||||
|         for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { | ||||
|             switch (device->vendor_id) { | ||||
| #ifndef GGML_VULKAN_RUN_TESTS | ||||
|         case VK_VENDOR_ID_AMD: | ||||
|         case VK_VENDOR_ID_INTEL: | ||||
|             device->mul_mat_l = false; | ||||
|             device->mul_mat_m = true; | ||||
|             device->mul_mat_s = true; | ||||
|             device->mul_mat_id_l = false; | ||||
|             device->mul_mat_id_m = true; | ||||
|             device->mul_mat_id_s = true; | ||||
|             break; | ||||
|         case VK_VENDOR_ID_APPLE: | ||||
|             device->mul_mat_l = false; | ||||
|             device->mul_mat_m = true; | ||||
|             device->mul_mat_s = false; | ||||
|             device->mul_mat_id_l = false; | ||||
|             device->mul_mat_id_m = true; | ||||
|             device->mul_mat_id_s = false; | ||||
|             break; | ||||
|             case VK_VENDOR_ID_AMD: | ||||
|             case VK_VENDOR_ID_INTEL: | ||||
|                 device->mul_mat_l[i] = false; | ||||
|                 device->mul_mat_m[i] = true; | ||||
|                 device->mul_mat_s[i] = true; | ||||
|                 device->mul_mat_id_l[i] = false; | ||||
|                 device->mul_mat_id_m[i] = true; | ||||
|                 device->mul_mat_id_s[i] = true; | ||||
|                 break; | ||||
|             case VK_VENDOR_ID_APPLE: | ||||
|                 device->mul_mat_l[i] = false; | ||||
|                 device->mul_mat_m[i] = true; | ||||
|                 device->mul_mat_s[i] = false; | ||||
|                 device->mul_mat_id_l[i] = false; | ||||
|                 device->mul_mat_id_m[i] = true; | ||||
|                 device->mul_mat_id_s[i] = false; | ||||
|                 break; | ||||
| #endif | ||||
|         default: | ||||
|             device->mul_mat_l = true; | ||||
|             device->mul_mat_m = true; | ||||
|             device->mul_mat_s = true; | ||||
|             device->mul_mat_id_l = true; | ||||
|             device->mul_mat_id_m = true; | ||||
|             device->mul_mat_id_s = true; | ||||
|             break; | ||||
|             default: | ||||
|                 device->mul_mat_l[i] = true; | ||||
|                 device->mul_mat_m[i] = true; | ||||
|                 device->mul_mat_s[i] = true; | ||||
|                 device->mul_mat_id_l[i] = true; | ||||
|                 device->mul_mat_id_m[i] = true; | ||||
|                 device->mul_mat_id_s[i] = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         ggml_vk_load_shaders(device); | ||||
|  | @ -3756,31 +3751,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int | |||
|     return split_k; | ||||
| } | ||||
| 
 | ||||
| static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); | ||||
| static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); | ||||
| 
 | ||||
|     if (ctx->device->coopmat2) { | ||||
|         if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { | ||||
|         if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { | ||||
|             return aligned ? mmp->a_l : mmp->l; | ||||
|         } | ||||
|         if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { | ||||
|         if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) { | ||||
|             return aligned ? mmp->a_m : mmp->m; | ||||
|         } | ||||
|         return aligned ? mmp->a_s : mmp->s; | ||||
|     } | ||||
| 
 | ||||
|     if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { | ||||
|     if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { | ||||
|         return aligned ? mmp->a_s : mmp->s; | ||||
|     } | ||||
|     if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { | ||||
|     if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { | ||||
|         return aligned ? mmp->a_m : mmp->m; | ||||
|     } | ||||
|     return aligned ? mmp->a_l : mmp->l; | ||||
| } | ||||
| 
 | ||||
| static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); | ||||
|     return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align; | ||||
| static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); | ||||
|     return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; | ||||
| } | ||||
| 
 | ||||
| static void ggml_vk_matmul( | ||||
|  | @ -3807,31 +3802,31 @@ static void ggml_vk_matmul( | |||
|     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); | ||||
| } | ||||
| 
 | ||||
| static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); | ||||
| static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); | ||||
| 
 | ||||
|     if (ctx->device->coopmat2) { | ||||
|         if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { | ||||
|         if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { | ||||
|             return aligned ? mmp->a_l : mmp->l; | ||||
|         } | ||||
|         if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { | ||||
|         if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) { | ||||
|             return aligned ? mmp->a_m : mmp->m; | ||||
|         } | ||||
|         return aligned ? mmp->a_s : mmp->s; | ||||
|     } | ||||
| 
 | ||||
|     if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { | ||||
|     if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { | ||||
|         return aligned ? mmp->a_s : mmp->s; | ||||
|     } | ||||
|     if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { | ||||
|     if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { | ||||
|         return aligned ? mmp->a_m : mmp->m; | ||||
|     } | ||||
|     return aligned ? mmp->a_l : mmp->l; | ||||
| } | ||||
| 
 | ||||
| static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); | ||||
|     return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; | ||||
| static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { | ||||
|     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); | ||||
|     return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; | ||||
| } | ||||
| 
 | ||||
| static void ggml_vk_matmul_id( | ||||
|  | @ -4012,10 +4007,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub | |||
|     const int y_ne = ne11 * ne10; | ||||
|     const int d_ne = ne11 * ne01; | ||||
| 
 | ||||
|     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); | ||||
|     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); | ||||
|     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; | ||||
| 
 | ||||
|     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); | ||||
|     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); | ||||
| 
 | ||||
|     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); | ||||
| 
 | ||||
|  | @ -4594,10 +4589,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& | |||
|     const uint64_t y_ne = ne11 * ne10; | ||||
|     const uint64_t d_ne = ne21 * ne20; | ||||
| 
 | ||||
|     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); | ||||
|     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); | ||||
|     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; | ||||
| 
 | ||||
|     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); | ||||
|     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); | ||||
| 
 | ||||
|     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); | ||||
|     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); | ||||
|  | @ -8036,13 +8031,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm | |||
|         case GGML_OP_MUL_MAT: | ||||
|         case GGML_OP_MUL_MAT_ID: | ||||
|             { | ||||
|                 ggml_type src0_type = op->src[0]->type; | ||||
|                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; | ||||
|                 const vk_device& device = ggml_vk_get_device(ctx->device); | ||||
|                 if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { | ||||
|                 if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { | ||||
|                     // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
 | ||||
|                     return false; | ||||
|                 } | ||||
|                 switch (op->src[0]->type) { | ||||
|                 switch (src0_type) { | ||||
|                     case GGML_TYPE_F32: | ||||
|                     case GGML_TYPE_F16: | ||||
|                     case GGML_TYPE_Q4_0: | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue