vulkan: further optimize mul_mat_vec using larger loads (#10387)
* vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec. Add some early returns for nonexistent rows in mul_mat_vec shaders. These can only be hit when dispatching a 2D grid of workgroups. Fix the logic for the 2D grid of workgroups to round up. Enable the pipeline robustness extension if it's available, and use it to disable robustness for these pipelines. The instructions to do the bounds checking contend for the same ALU resources as the bit twiddling dequant instructions. * vulkan: Add GLSL structure aliases for quant types to allow larger loads In Vulkan it's not possible to cast pointer types, so instead you have to declare an aliased binding for the memory with a different type. This commit adds aliases for the quant formats using 16b ints, and in a few places where the struct size is a multiple of 4 also using 32b ints. Currently only q4_k's aliases are used, but others will be used in subsequent commits. * vulkan: use larger loads in q5_k and q6_k shaders. Similar to the optimization I did in q4_k recently, this vectorizes some loads and reduces the number of bit twiddling instructions. * vulkan: use larger K step per iteration in mul_mat_vec. Add vec4 dequantization functions, and use them to do K=8 per iteration in mul_mat_vec. This uses 16b loads for the quant values and 128b loads for B which helps reduce the load on the memory system. The K_PER_ITER==2 logic is still there, just for F16/F32, and really only because they support unaligned sizes. Tweak the num_iters/unrolling logic to be simpler and catch a couple missed unrolling opportunities.
This commit is contained in:
parent
ad21c9e1f1
commit
1bacb9f625
11 changed files with 457 additions and 147 deletions
|
@ -3,7 +3,7 @@
|
|||
#ifdef FLOAT16
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
|
@ -12,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
|
||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
||||
#define K_PER_ITER 8
|
||||
#else
|
||||
#define K_PER_ITER 2
|
||||
#endif
|
||||
|
||||
|
||||
uint a_offset, b_offset, d_offset, y_offset;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||
{
|
||||
const uint col = i*BLOCK_SIZE + 2*tid;
|
||||
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
||||
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
|
||||
const uint iybs = col - col%QUANT_K; // y block start index
|
||||
|
||||
#if K_PER_ITER == 8
|
||||
#if QUANT_R == 2
|
||||
B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
|
||||
B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
|
||||
FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
|
||||
FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
|
||||
FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
|
||||
FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
|
||||
FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
|
||||
FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
|
||||
FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
|
||||
FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
|
||||
#else
|
||||
B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
|
||||
B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
|
||||
FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
|
||||
FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
|
||||
FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
|
||||
FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
|
||||
FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
|
||||
FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
|
||||
FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
|
||||
FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
|
||||
#endif
|
||||
#else
|
||||
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
||||
// accumulate it. We still fetch a pair of elements for A, which is fine for
|
||||
// quantized formats since they'll be within the same block. We should
|
||||
|
@ -34,9 +66,24 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
|
|||
if (!OOB) {
|
||||
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
|
||||
}
|
||||
#endif
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
|
||||
|
||||
#if K_PER_ITER == 8
|
||||
const vec4 v = dequantize4(ib, iqs, a_offset);
|
||||
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
|
||||
temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
|
||||
#else
|
||||
const vec2 v = dequantize(ib, iqs, a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
|
@ -44,6 +91,7 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
|
|||
if (!OOB) {
|
||||
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,22 +109,33 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
temp[i] = FLOAT_TYPE(0);
|
||||
}
|
||||
|
||||
const int unroll_count = 8;
|
||||
|
||||
const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
|
||||
const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
|
||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
|
||||
num_iters++;
|
||||
}
|
||||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i, false);
|
||||
i += 2;
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
while (i < num_iters) {
|
||||
iter(temp, first_row, num_rows, tid, i, true);
|
||||
i += 2;
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
|
||||
i++;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
|
@ -106,6 +165,9 @@ void main() {
|
|||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
} else {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue