diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp index a8d35ac38..2ec5d1eb6 100644 --- a/ggml-vulkan-shaders.hpp +++ b/ggml-vulkan-shaders.hpp @@ -12,11 +12,166 @@ const std::string shader_int8_ext = R"( #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require )"; +// Type-specific defines +const std::string shader_f16_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 + +#define A_TYPE float16_t +)"; +const std::string shader_q4_0_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; + +#define A_TYPE block_q4_0 +)"; +const std::string shader_q4_1_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 + +struct block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +#define A_TYPE block_q4_1 +)"; +const std::string shader_q5_0_defines = R"( +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#define QUANT_K 32 +#define QUANT_R 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +#define A_TYPE block_q5_0 +)"; +const std::string shader_q5_1_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +#define A_TYPE block_q5_1 +)"; +const std::string shader_q8_0_defines = R"( +#define QUANT_K 32 +#define QUANT_R 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; + +#define A_TYPE block_q8_0 +)"; + +// Dequant functions +const std::string shader_f16_dequant_func = R"( +#define DEQUANT_FUNC f16vec2 v = f16vec2(x[ib + 0], x[ib + 1]); +)"; +const std::string shader_f16_dequant_func_compat = R"( +#define DEQUANT_FUNC vec2 v = vec2(x[ib + 0], x[ib + 1]); +)"; + +const std::string shader_q4_0_dequant_func = R"( +#define DEQUANT_FUNC const float16_t d = x[ib].d; \ +const uint8_t vui = x[ib].qs[iqs]; \ +f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \ +v = (v - 8.0hf)*d; +)"; +const std::string shader_q4_0_dequant_func_compat = R"( +#define DEQUANT_FUNC const float d = float(x[ib].d); \ +const uint vui = uint(x[ib].qs[iqs]); \ +vec2 v = vec2(vui & 0xF, vui >> 4); \ +v = (v - 8.0f)*d; +)"; + +const std::string shader_q4_1_dequant_func = R"( +#define DEQUANT_FUNC const float16_t d = x[ib].d; \ +const float16_t m = x[ib].m; \ +const uint8_t vui = x[ib].qs[iqs]; \ +f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \ +v = v*d + m; +)"; +const std::string shader_q4_1_dequant_func_compat = R"( +#define DEQUANT_FUNC const float d = float(x[ib].d); \ +const float m = float(x[ib].m); \ +const uint vui = uint(x[ib].qs[iqs]); \ +vec2 v = vec2(vui & 0xF, vui >> 4); \ +v = v*d + m; +)"; + +const std::string shader_q5_0_dequant_func = R"( +#define DEQUANT_FUNC const float16_t d = x[ib].d; \ +const uint uint_qh = uint(x[ib].qh[1]) << 16 | x[ib].qh[0]; \ +const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ +const uint8_t vui = x[ib].qs[iqs]; \ +f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ +v = (v - 16.0hf) * d; +)"; +const std::string shader_q5_0_dequant_func_compat = R"( +#define DEQUANT_FUNC const float d = float(x[ib].d); \ +const uint uint_qh = uint(x[ib].qh[1]) << 16 | x[ib].qh[0]; \ +const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ +const uint vui = uint(x[ib].qs[iqs]); \ +vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ +v = (v - 16.0f) * d; +)"; + +const std::string shader_q5_1_dequant_func = R"( +#define DEQUANT_FUNC const float16_t d = x[ib].d; \ +const float16_t m = x[ib].m; \ +const ivec2 qh = ivec2(((x[ib].qh >> iqs) << 4) & 0x10, (x[ib].qh >> (iqs + 12)) & 0x10); \ +const uint8_t vui = x[ib].qs[iqs]; \ +f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ +v = v*d + m; +)"; +const std::string shader_q5_1_dequant_func_compat = R"( +#define DEQUANT_FUNC const float d = float(x[ib].d); \ +const float m = float(x[ib].m); \ +const ivec2 qh = ivec2(((x[ib].qh >> iqs) << 4) & 0x10, (x[ib].qh >> (iqs + 12)) & 0x10); \ +const uint vui = uint(x[ib].qs[iqs]); \ +vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ +v = v*d + m; +)"; + +const std::string shader_q8_0_dequant_func = R"( +#define DEQUANT_FUNC const float16_t d = x[ib].d; \ +f16vec2 v = f16vec2(x[ib].qs[iqs], x[ib].qs[iqs + 1]); \ +v = v * d; +)"; +const std::string shader_q8_0_dequant_func_compat = R"( +#define DEQUANT_FUNC const float d = float(x[ib].d); \ +vec2 v = vec2(int(x[ib].qs[iqs]), int(x[ib].qs[iqs + 1])); \ +v = v * d; +)"; + // MULMAT const std::string mulmat_head = R"( #version 450 +#extension GL_EXT_scalar_block_layout : require #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require @@ -30,7 +185,7 @@ const std::string mulmat_head = R"( const std::string mulmat_body = R"( layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A { A_TYPE data_a[]; }; +layout (binding = 0, scalar) readonly buffer A { A_TYPE data_a[]; }; layout (binding = 1) readonly buffer B { B_TYPE data_b[]; }; layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; }; @@ -238,28 +393,16 @@ void main() { const std::string dequant_head = R"( #version 450 +#extension GL_EXT_scalar_block_layout : require #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require )"; -const std::string dequant_q4_0_defines = R"( -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; - -#define A_TYPE block_q4_0 -)"; - const std::string dequant_body = R"( layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A { A_TYPE x[]; }; +layout (binding = 0, scalar) readonly buffer A { A_TYPE x[]; }; layout (binding = 1) writeonly buffer D { D_TYPE y[]; }; layout (push_constant) uniform parameter @@ -283,15 +426,16 @@ void main() { const int stride_a = p.stride_a / QUANT_K; - const int idx = col * stride_a + row; - const FLOAT_TYPE d = FLOAT_TYPE(x[idx].d); + const int ib = col * stride_a + row; - [[unroll]] for (int j = 0; j < QUANT_K/2; ++j) { - const FLOAT_TYPE x0 = FLOAT_TYPE((x[idx].qs[j] & 0x0F) - 8); - const FLOAT_TYPE x1 = FLOAT_TYPE((x[idx].qs[j] >> 4) - 8); + const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + const int step = QUANT_R == 1 ? 2 : 1; - y[col * p.stride_b + row*QUANT_K + j + 0 ] = D_TYPE(x0*d); - y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = D_TYPE(x1*d); + [[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) { + DEQUANT_FUNC + + y[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x); + y[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y); } } )"; @@ -300,64 +444,16 @@ void main() { const std::string mul_mat_vec_head = R"( #version 450 +#extension GL_EXT_scalar_block_layout : require #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require )"; -const std::string mul_mat_vec_f16_defines = R"( -#define QUANT_K 32 -#define QUANT_R 2 -#define BLOCK_SIZE 32 - -#define A_TYPE float16_t -)"; - -const std::string mul_mat_vec_f16_dequant_func = R"( -#define DEQUANT_FUNC float16_t v0 = x[ib + 0]; \ -float16_t v1 = x[ib + 1]; -)"; - -const std::string mul_mat_vec_f16_dequant_func_compat = R"( -#define DEQUANT_FUNC float v0 = float(x[ib + 0]); \ -float v1 = float(x[ib + 1]); -)"; - -const std::string mul_mat_vec_q4_0_defines = R"( -#define QUANT_K 32 -#define QUANT_R 2 -#define BLOCK_SIZE 32 - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; -#define A_TYPE block_q4_0 -)"; - -const std::string mul_mat_vec_q4_0_dequant_func = R"( -#define DEQUANT_FUNC const float16_t d = x[ib].d; \ -const uint8_t vui = x[ib].qs[iqs]; \ -const int8_t vi0 = int8_t(vui & 0xF); \ -const int8_t vi1 = int8_t(vui >> 4); \ -float16_t v0 = float16_t(vi0 - 8)*d; \ -float16_t v1 = float16_t(vi1 - 8)*d; -)"; - -const std::string mul_mat_vec_q4_0_dequant_func_compat = R"( -#define DEQUANT_FUNC const float d = float(x[ib].d); \ -const uint vui = uint(x[ib].qs[iqs]); \ -const int vi0 = int(vui) & 0xF; \ -const int vi1 = int(vui) >> 4; \ -float v0 = float(vi0 - 8)*d; \ -float v1 = float(vi1 - 8)*d; -)"; - const std::string mul_mat_vec_body = R"( -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = QUANT_K, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A { A_TYPE x[]; }; +layout (binding = 0, scalar) readonly buffer A { A_TYPE x[]; }; layout (binding = 1) readonly buffer B { B_TYPE y[]; }; layout (binding = 2) writeonly buffer D { D_TYPE dst[]; }; @@ -366,14 +462,14 @@ layout (push_constant) uniform parameter int ncols; } p; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; +shared FLOAT_TYPE tmp[QUANT_K]; void main() { const int block_size = int(gl_WorkGroupSize.x); const int row = int(gl_WorkGroupID.x); const int tid = int(gl_LocalInvocationID.x); - const int y_offset = QUANT_K/2; + const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; tmp[tid] = FLOAT_TYPE(0.0f); @@ -386,8 +482,8 @@ void main() { DEQUANT_FUNC // matrix multiplication - tmp[tid] += FLOAT_TYPE(v0) * FLOAT_TYPE(y[iybs + iqs + 0]); - tmp[tid] += FLOAT_TYPE(v1) * FLOAT_TYPE(y[iybs + iqs + y_offset]); + tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(y[iybs + iqs + 0]); + tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(y[iybs + iqs + y_offset]); } // sum up partial sums and write back result diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index ebe355ce8..d45d9ee26 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -52,6 +52,8 @@ #define VK_SUBMIT_BATCH 3 +#define VK_NUM_TYPES 16 + typedef void (*ggml_vk_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); struct vk_buffer { @@ -157,12 +159,12 @@ vk_pipeline vk_pipeline_matmul_f16_aligned_l, vk_pipeline_matmul_f16_aligned_m, vk_pipeline vk_pipeline_matmul_f16_f32_l, vk_pipeline_matmul_f16_f32_m, vk_pipeline_matmul_f16_f32_s; vk_pipeline vk_pipeline_matmul_f16_f32_aligned_l, vk_pipeline_matmul_f16_f32_aligned_m, vk_pipeline_matmul_f16_f32_aligned_s; vk_pipeline vk_pipeline_matmul_split_k_reduce; -vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16, vk_pipeline_dequant_mul_mat_vec_q4_0; -vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16_f32, vk_pipeline_dequant_mul_mat_vec_q4_0_f32; +vk_pipeline vk_pipeline_dequant[VK_NUM_TYPES]; +vk_pipeline vk_pipeline_dequant_mul_mat_vec[VK_NUM_TYPES]; +vk_pipeline vk_pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES]; vk_pipeline vk_pipeline_mul_f32; vk_pipeline vk_pipeline_add_f32, vk_pipeline_add_f16_f32_f16; vk_pipeline vk_pipeline_scale_f32; -vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0; static std::vector> vk_pinned_memory; @@ -651,6 +653,31 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) { } } +static inline bool ggml_vk_build_shader_type_defines(std::stringstream& stream, ggml_type type, bool compat) { + switch(type) { + case GGML_TYPE_F16: + stream << shader_f16_defines << (compat ? shader_f16_dequant_func_compat : shader_f16_dequant_func); + return true; + case GGML_TYPE_Q4_0: + stream << shader_q4_0_defines << (compat ? shader_q4_0_dequant_func_compat : shader_q4_0_dequant_func); + return true; + case GGML_TYPE_Q4_1: + stream << shader_q4_1_defines << (compat ? shader_q4_1_dequant_func_compat : shader_q4_1_dequant_func); + return true; + case GGML_TYPE_Q5_0: + stream << shader_q5_0_defines << (compat ? shader_q5_0_dequant_func_compat : shader_q5_0_dequant_func); + return true; + case GGML_TYPE_Q5_1: + stream << shader_q5_1_defines << (compat ? shader_q5_1_dequant_func_compat : shader_q5_1_dequant_func); + return true; + case GGML_TYPE_Q8_0: + stream << shader_q8_0_defines << (compat ? shader_q8_0_dequant_func_compat : shader_q8_0_dequant_func); + return true; + default: + return false; + } +} + static void ggml_vk_generate_shaders() { #ifdef VK_DEBUG std::cerr << "ggml_vk_generate_shaders()" << std::endl; @@ -705,65 +732,46 @@ static void ggml_vk_generate_shaders() { vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - // Build dequant q4_0 - stream.str(""); - stream.clear(); + // Build dequant shaders + vk_pipeline_dequant[GGML_TYPE_F32] = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, {}, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); - stream << dequant_head << shader_float_type << dequant_q4_0_defines << dequant_body; + for (int i = 0; i < VK_NUM_TYPES; i++) { + stream.str(""); + stream.clear(); - vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); + stream << dequant_head << shader_float_type; + if (vk_device.fp16) { + stream << shader_int8_ext; + } + + if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) { + continue; + } + + stream << dequant_body; + + vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); + } // mul mat vec - stream.str(""); - stream.clear(); + for (int i = 0; i < VK_NUM_TYPES; i++) { + stream.str(""); + stream.clear(); - stream << mul_mat_vec_head << shader_float_type; - if (vk_device.fp16) { - stream << shader_int8_ext << mul_mat_vec_q4_0_dequant_func; - } else { - stream << mul_mat_vec_q4_0_dequant_func_compat; + stream << mul_mat_vec_head << shader_float_type; + if (vk_device.fp16) { + stream << shader_int8_ext; + } + + if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) { + continue; + } + + stream << mul_mat_vec_body; + + vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); } - stream << mul_mat_vec_q4_0_defines << mul_mat_vec_body; - - vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), { "D_TYPE", "float", "B_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - - stream.str(""); - stream.clear(); - - stream << mul_mat_vec_head << shader_float_type; - if (vk_device.fp16) { - stream << shader_int8_ext << mul_mat_vec_q4_0_dequant_func; - } else { - stream << mul_mat_vec_q4_0_dequant_func_compat; - } - stream << mul_mat_vec_q4_0_defines << mul_mat_vec_body; - - vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0_f32", stream.str(), { "D_TYPE", "float", "B_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - - stream.str(""); - stream.clear(); - - stream << mul_mat_vec_head << shader_float_type; - if (vk_device.fp16) { - stream << shader_int8_ext << mul_mat_vec_f16_dequant_func; - } else { - stream << mul_mat_vec_f16_dequant_func_compat; - } - stream << mul_mat_vec_f16_defines << mul_mat_vec_body; - - vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16", stream.str(), { "D_TYPE", "float", "B_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - - stream.str(""); - stream.clear(); - - stream << mul_mat_vec_head << shader_float_type; - if (vk_device.fp16) { - stream << shader_int8_ext << mul_mat_vec_f16_dequant_func; - } else { - stream << mul_mat_vec_f16_dequant_func_compat; - } - stream << mul_mat_vec_f16_defines << mul_mat_vec_body; - vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16_f32", stream.str(), { "D_TYPE", "float", "B_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); // add stream.str(""); @@ -779,7 +787,6 @@ static void ggml_vk_generate_shaders() { // Static shaders vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_string("split_k_reduce", mulmat_split_k_reduce_src, {}, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); - vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, {}, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_string("scale_f32", scale_src, { "X_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); @@ -994,32 +1001,42 @@ void ggml_vk_init(void) { #endif } -static vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) { +static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) { #ifdef VK_DEBUG std::cerr << "ggml_vk_get_to_fp16()" << std::endl; #endif switch (type) { - case GGML_TYPE_Q4_0: - return &vk_pipeline_dequant_q4_0; case GGML_TYPE_F32: - return &vk_pipeline_f32_to_f16; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + break; default: return nullptr; } + + return &vk_pipeline_dequant[type]; } -static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) { +static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) { #ifdef VK_DEBUG std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl; #endif switch (type) { - case GGML_TYPE_Q4_0: - return f16_y ? &vk_pipeline_dequant_mul_mat_vec_q4_0 : &vk_pipeline_dequant_mul_mat_vec_q4_0_f32; case GGML_TYPE_F16: - return f16_y ? &vk_pipeline_dequant_mul_mat_vec_f16 : &vk_pipeline_dequant_mul_mat_vec_f16_f32; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + break; default: return nullptr; } + + return f16_y ? &vk_pipeline_dequant_mul_mat_vec[type] : &vk_pipeline_dequant_mul_mat_vec_f32[type]; } // buffer pool for vulkan