diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp
index cf52927b3..0ce57f30c 100644
--- a/ggml-vulkan-shaders.hpp
+++ b/ggml-vulkan-shaders.hpp
@@ -85,6 +85,20 @@ struct block_q8_0
 #define A_TYPE block_q8_0
 )";
 
+const std::string shader_q6_K_defines = R"(
+#define QUANT_K 256
+
+struct block_q6_K
+{
+    uint8_t ql[QUANT_K/2];
+    uint8_t qh[QUANT_K/4];
+    int8_t scales[QUANT_K/16];
+    float16_t d;
+};
+
+#define A_TYPE block_q6_K
+)";
+
 // Dequant functions
 const std::string shader_f16_dequant_func = R"(
 #define DEQUANT_FUNC f16vec2 v = f16vec2(x[ib + 0], x[ib + 1]);
@@ -438,6 +452,42 @@ void main() {
 }
 )";
 
+// K-quants
+const std::string dequant_q6_K_body = R"(
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A { A_TYPE x[]; };
+layout (binding = 1) writeonly buffer D { D_TYPE y[]; };
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int K;
+    int stride_a;
+    int stride_b;
+} p;
+
+void main() {
+    const int i = int(gl_WorkGroupID.x); // block index: each workgroup reads x[i] and writes y[i * QUANT_K ...]
+    const int tid = int(gl_LocalInvocationID.x); // 0...63
+    const int ip = tid / 32; // 0 or 1: selects the 128-value half of the block
+    const int il = tid - 32 * ip; // 0...31
+    const int is = 8 * ip + il / 16; // base index into scales[] for this invocation
+
+    const int y_idx = i * QUANT_K + 128 * ip + il;
+
+    const int ql_idx = 64 * ip + il;
+    const uint8_t qh = x[i].qh[32 * ip + il]; // holds the upper 2 bits of each of the 4 values written below
+
+    const FLOAT_TYPE d = FLOAT_TYPE(x[i].d);
+
+    y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
+    y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
+    y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
+    y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
+}
+)";
+
 // Mul Mat Vec
 const std::string mul_mat_vec_head = R"(
 #version 450
@@ -497,6 +547,90 @@ void main() {
 }
 )";
 
+const std::string mul_mat_vec_q6_K_body = R"(
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A { A_TYPE x[]; };
+layout (binding = 1) readonly buffer B { B_TYPE y[]; };
+layout (binding = 2) writeonly buffer D { D_TYPE dst[]; };
+
+layout (push_constant) uniform parameter
+{
+    int ncols;
+} p;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+    const int row = int(gl_WorkGroupID.x); // one workgroup per output row
+
+    const int num_blocks_per_row = p.ncols / QUANT_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+    const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+    const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int v_in = tid - step*v_im; // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * v_in; // 0, 4, 8, ..., 28
+    const int is = v_in / 4; // 0 or 1
+#endif
+
+    const int ql_offset = 64*v_im + l0;
+    const int qh_offset = 32*v_im + l0;
+    const int s_offset = 8*v_im + is;
+    const int y_offset = 128*v_im + l0;
+
+    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+        const int y_idx = i * QUANT_K + y_offset;
+
+        const FLOAT_TYPE d = x[ib0 + i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
+                       + FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
+                       + FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
+                       + FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
+                       + FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
+                       + FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
+                       + FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
+                       + FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
+        tmp[16 * ix + tid] += sum;
+#else
+        FLOAT_TYPE sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
+                 + FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
+                 + FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
+                 + FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp[16 * ix + tid] += sum;
+#endif
+    }
+
+    // sum up partial sums and write back result; indexed by the raw invocation id because tid repeats when K_QUANTS_PER_ITERATION == 2
+    barrier();
+    [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+        if (int(gl_LocalInvocationID.x) < s) {
+            tmp[int(gl_LocalInvocationID.x)] += tmp[int(gl_LocalInvocationID.x) + s];
+        }
+        barrier();
+    }
+    if (int(gl_LocalInvocationID.x) == 0) {
+        dst[row] = D_TYPE(tmp[0]);
+    }
+}
+)";
+
 // F16 to F32
 const std::string f32_to_f16_src = R"(
 #version 450
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index cf42fa768..19b3dd397 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -54,6 +54,12 @@
 
 #define VK_NUM_TYPES 16
 
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 1
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
 typedef void (*ggml_vk_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 
 struct vk_buffer {
@@ -714,6 +720,9 @@ static inline bool ggml_vk_build_shader_type_defines(std::stringstream& stream,
         case GGML_TYPE_Q8_0:
             stream << shader_q8_0_defines << (compat ? shader_q8_0_dequant_func_compat : shader_q8_0_dequant_func);
             return true;
+        case GGML_TYPE_Q6_K:
+            stream << shader_q6_K_defines;
+            return true;
         default:
             return false;
     }
@@ -789,9 +798,20 @@ static void ggml_vk_generate_shaders() {
             continue;
         }
 
-        stream << dequant_body;
+        int work_group_denom;
 
-        vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
+        switch ((ggml_type)i) {
+            case GGML_TYPE_Q6_K:
+                stream << dequant_q6_K_body;
+                work_group_denom = 64 * 4;
+                break;
+            default:
+                stream << dequant_body;
+                work_group_denom = 256 * 32;
+                break;
+        }
+
+        vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {work_group_denom, 1, 1}, {}, 1);
     }
 
     // mul mat vec
@@ -808,10 +828,17 @@ static void ggml_vk_generate_shaders() {
             continue;
         }
 
-        stream << mul_mat_vec_body;
+        switch ((ggml_type)i) {
+            case GGML_TYPE_Q6_K:
+                stream << mul_mat_vec_q6_K_body;
+                break;
+            default:
+                stream << mul_mat_vec_body;
+                break;
+        }
 
-        vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
-        vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+        vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t", "K_QUANTS_PER_ITERATION", std::to_string(K_QUANTS_PER_ITERATION) }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
+        vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float", "K_QUANTS_PER_ITERATION", std::to_string(K_QUANTS_PER_ITERATION) }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
     }
 
     // add
@@ -1049,6 +1076,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q6_K:
             break;
         default:
             return nullptr;
@@ -1068,6 +1096,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q6_K:
             break;
         default:
             return nullptr;
@@ -1205,7 +1234,7 @@ void ggml_vk_host_free(void* ptr) {
         }
     }
     if (buf == nullptr) {
-        fprintf(stderr, "WARNING: to free pinned memory: memory not in map\n");
+        fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
         return;
     }
 