diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index ed924a072..ae78a2d14 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint itid8 = itid%8;
 
     const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = itid - 8*v_im; // 0...15 or 0...7
+    const uint v_in = itid - 8*v_im; // 0...7
 
     const uint l0 = 2*v_in; // 0...15
     const uint q_offset = 32*v_im + l0;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
index 54085f90e..cbe3269d6 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint itid8 = itid%8;
 
     const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = itid - 8*v_im; // 0...15 or 0...7
+    const uint v_in = itid - 8*v_im; // 0...7
 
     const uint8_t m = uint8_t(1 << (4 * v_im));
 
@@ -47,6 +47,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32);
         barrier();
 
+        // 0, 1, 16, 17
+        uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
+        qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
+        const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
+        const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
+        const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
+        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
+
+        const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in]));
+        const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8]));
+
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
@@ -60,14 +71,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
             FLOAT_TYPE sum = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m     )) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m     )) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)),
-                      fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - FLOAT_TYPE((( hmk0[l] & (m     )) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m     )) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4),
+                      fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4), sum))))))));
             }
             temp[j][n] = fma(d, sum, temp[j][n]);
         }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index 28cde16db..f8bc885dc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint ix = tid/16;
 
     const uint il = itid/4; // 0...3
-    const uint ir = itid - 4*il; // 0...7 or 0...3
+    const uint ir = itid - 4*il; // 0...3
     const uint n = 4;
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
index 9039ce1f3..72b36d9af 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint ix = tid/16;
 
     const uint il = itid/4; // 0...3
-    const uint ir = itid - 4*il; // 0...7 or 0...3
+    const uint ir = itid - 4*il; // 0...3
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const uint v_in = il % 2;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
index 231bfc5da..efbdfe939 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -87,7 +87,7 @@ void compute_outputs(const uint first_row, const uint num_rows) {
     const uint ix = tid/16;
 
     const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
-    const uint v_in = itid - 8*v_im; // 0...15 or 0...7
+    const uint v_in = itid - 8*v_im; // 0...7
 
     const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
     const uint is = v_in / 4;
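Note (not part of the patch): the q3_k hunk replaces eight per-byte buffer reads of qs per inner iteration with a single packed 32-bit word that is shifted and masked with 0x03030303, then split back into bytes via unpack8. Below is a minimal host-side sketch in plain C of that bit trick, showing the packed extraction matches the old per-byte "(q >> shift) & 3" path; the byte values q0/q1/q16/q17 are made-up stand-ins, not real quant data.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* Stand-ins for data_a[ib0 + i].qs[q_offset], [q_offset + 1], [q_offset + 16], [q_offset + 17]. */
        const uint8_t q0 = 0xB4, q1 = 0x6E, q16 = 0x93, q17 = 0xD1;

        /* Same packing as the shader: bytes 0, 1, 16, 17 of the block go into one 32-bit word. */
        uint32_t qs_u32 = (uint32_t)q0 | ((uint32_t)q1 << 8);
        qs_u32 |= ((uint32_t)q16 | ((uint32_t)q17 << 8)) << 16;

        const uint8_t qs[4] = { q0, q1, q16, q17 };
        for (int shift = 0; shift < 8; shift += 2) {
            /* Four 2-bit quants, one per byte lane (what vec4(unpack8(...)) sees in the shader). */
            uint32_t packed = (qs_u32 >> shift) & 0x03030303u;
            for (int b = 0; b < 4; ++b) {
                uint8_t via_u32  = (uint8_t)((packed >> (8 * b)) & 0xFF); /* packed extraction */
                uint8_t per_byte = (uint8_t)((qs[b] >> shift) & 3);       /* old per-byte access */
                assert(via_u32 == per_byte);
            }
        }
        printf("packed 2-bit extraction matches the per-byte path\n");
        return 0;
    }

The mask never mixes lanes because each 2-bit field sits entirely inside its own byte for shifts 0, 2, 4, and 6; the same reasoning applies to the hmask bytes packed through data_a_packed16 and unpack8 in the shader.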