diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 5fc1ba4ad..6945b0b51 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -30,9 +30,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a_packed16[a_offset + ib].d); const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); - return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d; + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f); } #endif @@ -60,12 +59,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a_packed16[a_offset + ib].d); const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); - return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d; + return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f); } #endif @@ -95,10 +93,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a_packed16[a_offset + ib].d); uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; - return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d; + return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)); } #endif @@ -109,8 +106,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a_packed16[a_offset + ib].d); const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); - return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d; + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2d5b8e466..433da8593 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -31,27 +31,13 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_ #if K_PER_ITER == 8 #if QUANT_R == 2 - B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4]; - B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4]; - FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x); - FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x); - FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y); - FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y); - FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z); - FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z); - FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w); - FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w); + const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4]; + const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4]; + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); #else - B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4]; - B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1]; - FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x); - FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y); - FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z); - FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w); - FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x); - FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y); - FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z); - FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w); + const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]); #endif #else // Check if the second of the pair of elements is OOB, and don't fetch B or @@ -71,18 +57,20 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_ const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index #if K_PER_ITER == 8 + // TODO: can we dequant as f16 instead of as vec? const vec4 v = dequantize4(ib, iqs, a_offset); const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + FLOAT_TYPE rowtmp = 0; // matrix multiplication - temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]); - temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]); - temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]); - temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]); - temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]); - temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]); - temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]); - temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]); + rowtmp += dot(bv0, v); + rowtmp += dot(bv1, v2); + +#if !defined(DATA_A_Q4_1) && !defined(DATA_A_Q5_1) + const float d = float(data_a[a_offset + ib].d); + rowtmp *= d; +#endif + temp[n] += rowtmp; #else const vec2 v = dequantize(ib, iqs, a_offset);