diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5c0d96e49..c694f3087 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1642,10 +1642,10 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - //CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - //CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - //CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - //CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) @@ -1659,10 +1659,10 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - //CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - //CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - //CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - //CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 175e31fa7..5ff55f8e5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -301,6 +301,130 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint8_t qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return ret; +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib4 = (idx & 0xFC) >> 4; // 0..63 + const uint is16 = QUANT_K / 8 + 2 * ib32; // index in packed16 + + const uint8_t qs = bl.block.qs[ib4]; + const uint signscale = pack32(u16vec2(bl16.block.qs[is16], bl16.block.qs[is16+1])); + + const float16_t dscale = bl.block.d * 0.5hf * (0.5hf + float16_t(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib4 & 3), 7); + sign |= bitCount(sign) << 7; + + const uint8_t g = unpack8(iq3xxs_grid[qs])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return ret; +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) >> 5; // 0..7 + const uint qhbit = iqs & 7; + const uint isgn = (idx & 0xF8) >> 3; // 0..31 + const uint is = (idx & 0xC0) >> 6; // 0..3 + + const uint8_t scale = (bl.block.scales[is] >> ((idx & 0x20) >> 3)) & uint8_t(0xF); + const float16_t dscale = d * (1.0hf + float16_t(2 * scale)); + + const uint qs = bl.block.qs[iqs]; + const uint qh = (bl.block.qh[iqh] << (8 - qhbit)) & 0x100; + const uint8_t sign = bl.block.signs[isgn]; + + const uint g = unpack8(iq3s_grid[qs | qh])[idx & 3]; + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + + return ret; +} +#endif + + #if defined(DATA_A_IQ4_NL) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { block_iq4_nl block; @@ -340,6 +464,14 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ5_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8b5c186b7..52cf95363 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -314,10 +314,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); for (const auto& tname : type_names) { - if (tname == "iq2_xs" && coopmat2) continue; - if (tname == "iq2_xxs" && coopmat2) continue; - if (tname == "iq3_xxs" && coopmat2) continue; - if (tname == "iq3_s" && coopmat2) continue; std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; @@ -373,9 +369,6 @@ void process_shaders() { if (tname == "f32") { continue; } - if (tname == "iq3_s" || tname == "iq3_xxs" || tname == "iq2_xs" || tname == "iq2_xxs") { - continue; - } if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",