Simplify context use, optimize matmul shader for warp size 64 (AMD GCN), fix split_k matmul shader optimization

This commit is contained in:
0cc4m 2024-01-27 21:36:48 +01:00
parent a5cca6cd8c
commit 48ad459efc
3 changed files with 20746 additions and 20788 deletions

File diff suppressed because it is too large. Load diff

File diff suppressed because it is too large. Load diff

View file

@@ -243,8 +243,6 @@ mulmat_head = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#define WARP 32
#ifndef LOAD_VEC
#define LOAD_VEC 1
#endif
@@ -266,7 +264,6 @@ layout (push_constant) uniform parameter
uint stride_b;
uint stride_d;
uint k_split;
uint d_offset;
uint ne02;
uint ne12;
@@ -286,6 +283,7 @@ layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * (BK+1)];
shared FLOAT_TYPE buf_b[BN * (BK+1)];
@@ -299,9 +297,9 @@ void main() {
const uint batch_idx_a = i03 * p.ne02 + i02;
const uint blocks_x = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_x;
const uint ik = gl_WorkGroupID.x / blocks_x;
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
const uint warp_i = gl_LocalInvocationID.x / WARP;
@@ -354,7 +352,7 @@ void main() {
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc + l < p.M && block + loadr < p.K) {
if (ir * BM + loadc + l < p.M && block + loadr < end_k) {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
} else {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
@@ -379,7 +377,7 @@ void main() {
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
#else
if (ic * BN + loadc + l < p.N && block + loadr < p.K) {
if (ic * BN + loadc + l < p.N && block + loadr < end_k) {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
} else {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
@@ -422,7 +420,7 @@ void main() {
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
const uint k_split_offset = ik * p.M * p.N;
const uint offsets = gl_GlobalInvocationID.z * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -432,7 +430,7 @@ void main() {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[p.d_offset + gl_GlobalInvocationID.z * p.batch_stride_d + k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
}
}
@@ -443,7 +441,9 @@ void main() {
mulmat_split_k_reduce_src = """#version 450
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -451,7 +451,6 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint ne;
uint k_num;
uint d_offset;
} p;
void main() {
@@ -463,11 +462,11 @@ void main() {
float result = 0.0f;
for (int i = 0; i < p.k_num; i++) {
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a[i * p.ne + idx];
}
data_d[p.d_offset + idx] = result;
data_d[idx] = result;
}
"""