diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp index 0ce57f30c..a7cb149ca 100644 --- a/ggml-vulkan-shaders.hpp +++ b/ggml-vulkan-shaders.hpp @@ -468,23 +468,28 @@ layout (push_constant) uniform parameter } p; void main() { - const int i = int(gl_WorkGroupID.x); - const int tid = int(gl_LocalInvocationID.x); - const int ip = tid / 32; - const int il = tid - 32 * ip; - const int is = 8 * ip + il / 16; + for (int wgy = 0; wgy < 256; wgy++) { + const int i = int(gl_WorkGroupID.x * 256 + wgy); + if (i >= p.M * p.K / QUANT_K) { + return; + } + const int tid = int(gl_LocalInvocationID.x); + const int ip = tid / 32; + const int il = tid - 32 * ip; + const int is = 8 * ip + il / 16; - const int y_idx = i * QUANT_K + 128 * ip + il; + const int y_idx = i * QUANT_K + 128 * ip + il; - const int ql_idx = 64 * ip + il; - const uint8_t qh = x[i].qh[32 * ip + il]; + const int ql_idx = 64 * ip + il; + const uint8_t qh = x[i].qh[32 * ip + il]; - const FLOAT_TYPE d = FLOAT_TYPE(x[i].d); + const FLOAT_TYPE d = FLOAT_TYPE(x[i].d); - y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); - y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); - y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); - y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + } } )"; diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 63fd71d6c..683add489 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -797,20 +797,16 @@ static void ggml_vk_generate_shaders() { continue; } - int work_group_denom; - switch ((ggml_type)i) { case GGML_TYPE_Q6_K: stream << dequant_q6_K_body; - work_group_denom = 64 * 4; break; default: stream << dequant_body; - work_group_denom = 256 * 32; break; } - vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1); + vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1); } // mul mat vec @@ -891,6 +887,8 @@ void ggml_vk_init(void) { }; validation_features.setPNext(nullptr); instance_create_info.setPNext(&validation_features); + +std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; #endif vk_instance = vk::createInstance(instance_create_info);