dot and delta optimization

This commit is contained in:
Eve 2024-12-03 11:57:37 -05:00
parent 642330ac7c
commit b7ad234517
2 changed files with 20 additions and 36 deletions

View file

@ -30,9 +30,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
} }
// dequantize4 (4-bit variant), shown as a side-by-side diff: on each line the
// left text is the pre-commit version and the right text is the post-commit version.
// Reads one packed word from data_a_packed16[a_offset + ib].qs[iqs/2], splits it
// into four 4-bit fields (bits 0-3, 4-7, 8-11, 12-15) and subtracts 8.0 from each.
// Pre-commit: the result is multiplied by the block's `d` field.
// Post-commit: the `d` load and multiply are removed and the unscaled vec4 is returned.
// NOTE(review): the scale must then be applied by the caller; confirm every call site does so.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d; return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f);
} }
#endif #endif
@ -60,12 +59,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
} }
// dequantize4 (4-bit + high-bit variant), shown as a side-by-side diff: on each line
// the left text is the pre-commit version and the right text is the post-commit version.
// Builds a 32-bit mask `uint_qh` from qh[1] (upper half) and qh[0] (lower half), then
// combines one bit of it with each 4-bit field of the packed qs word and subtracts 16.0.
// Pre-commit: the result is multiplied by the block's `d` field.
// Post-commit: the `d` load and multiply are removed and the unscaled vec4 is returned.
// NOTE(review): the scale must then be applied by the caller; confirm every call site does so.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
// qh0: bit `iqs` and bit `iqs + 16` of uint_qh, each moved to bit position 4 (mask 0x10).
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
// qh1: same extraction for bit `iqs + 1` and bit `iqs + 17`.
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
// Each 4-bit field of vui is OR-ed with its extracted high bit before the 16.0 offset is removed.
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d; return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f);
} }
#endif #endif
@ -95,10 +93,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d; return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
} }
// dequantize4 (8-bit variant), shown as a side-by-side diff: on each line the
// left text is the pre-commit version and the right text is the post-commit version.
// Reads two consecutive packed words, qs[iqs/2] and qs[iqs/2 + 1], and converts the
// low byte and the next byte of each to int8_t, giving four values in a vec4.
// Pre-commit: the result is multiplied by the block's `d` field.
// Post-commit: the `d` load and multiply are removed and the unscaled vec4 is returned.
// NOTE(review): the scale must then be applied by the caller; confirm every call site does so.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d; return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF));
} }
#endif #endif
@ -109,8 +106,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
} }
// dequantize4 (lookup-table variant), shown as a side-by-side diff: on each line the
// left text is the pre-commit version and the right text is the post-commit version.
// Reads one packed word from data_a_packed16[a_offset + ib].qs[iqs/2] and uses each of
// its four 4-bit fields (bits 0-3, 4-7, 8-11, 12-15) as an index into kvalues_iq4nl.
// Pre-commit: the looked-up values are multiplied by the block's `d` field.
// Post-commit: the `d` load and multiply are removed and the unscaled vec4 is returned.
// NOTE(review): the scale must then be applied by the caller; confirm every call site does so.
vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d; return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]);
} }
#endif #endif

View file

@ -31,27 +31,13 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
#if K_PER_ITER == 8 #if K_PER_ITER == 8
#if QUANT_R == 2 #if QUANT_R == 2
B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4]; const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4]; const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x); const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
#else #else
B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4]; const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1]; const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
#endif #endif
#else #else
// Check if the second of the pair of elements is OOB, and don't fetch B or // Check if the second of the pair of elements is OOB, and don't fetch B or
@ -71,18 +57,20 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
#if K_PER_ITER == 8 #if K_PER_ITER == 8
// TODO: can we dequant as f16 instead of as vec?
const vec4 v = dequantize4(ib, iqs, a_offset); const vec4 v = dequantize4(ib, iqs, a_offset);
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
FLOAT_TYPE rowtmp = 0;
// matrix multiplication // matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]); rowtmp += dot(bv0, v);
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]); rowtmp += dot(bv1, v2);
temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]); #if !defined(DATA_A_Q4_1) && !defined(DATA_A_Q5_1)
temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]); const float d = float(data_a[a_offset + ib].d);
temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]); rowtmp *= d;
temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]); #endif
temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]); temp[n] += rowtmp;
#else #else
const vec2 v = dequantize(ib, iqs, a_offset); const vec2 v = dequantize(ib, iqs, a_offset);