better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations

This commit is contained in:
Eve 2025-01-07 19:57:55 -05:00
parent 6f5d62b098
commit 91f1d9ce99

View file

@ -8,7 +8,73 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row)
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
if (i >= num_blocks_per_row)
continue;
}
const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
const uint32_t qh4_u32 = (qh_u32 & 0x30303030);
const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
const vec4 q0 = vec4(unpack8(q0_u32)) - 32;
const vec4 q1 = vec4(unpack8(q1_u32)) - 32;
const vec4 q2 = vec4(unpack8(q2_u32)) - 32;
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
if (all_threads) {
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
}
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
[[unroll]] for (uint l = 0; l < 4; ++l) {
sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]);
sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]);
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
}
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
}
}
}
void compute_outputs(const uint first_row, const uint num_rows) {
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
@ -31,65 +97,19 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint s_offset = 8*v_im + is; const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0); temp[j][i] = FLOAT_TYPE(0);
} }
} }
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint nbr_par_th = num_blocks_per_row%it_size;
const uint y_idx = i * QUANT_K + y_offset; const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (uint n = 0; n < num_rows; ++n) { [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
uint32_t qh4_u32 = (qh_u32 & 0x30303030);
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
uvec4 q0 = uvec4(unpack8(q0_u32));
uvec4 q1 = uvec4(unpack8(q1_u32));
uvec4 q2 = uvec4(unpack8(q2_u32));
uvec4 q3 = uvec4(unpack8(q3_u32));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
FLOAT_TYPE sum[4] = {0, 0, 0, 0};
[[unroll]] for (uint l = 0; l < 4; ++l) {
sum[0] = fma(FLOAT_TYPE(by0[l]), FLOAT_TYPE(int8_t(q0[l]) - 32), sum[0]);
sum[1] = fma(FLOAT_TYPE(by32[l]), FLOAT_TYPE(int8_t(q1[l]) - 32), sum[1]);
sum[2] = fma(FLOAT_TYPE(by64[l]), FLOAT_TYPE(int8_t(q2[l]) - 32), sum[2]);
sum[3] = fma(FLOAT_TYPE(by96[l]), FLOAT_TYPE(int8_t(q3[l]) - 32), sum[3]);
}
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
}
}
} }
calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid); reduce_result(temp, d_offset, first_row, num_rows, tid);
} }