Simplify context use, optimize matmul shader for warp size 64 (AMD GCN), fix split_k matmul shader optimization
This commit is contained in:
parent
a5cca6cd8c
commit
48ad459efc
3 changed files with 20746 additions and 20788 deletions
40909
ggml-vulkan-shaders.hpp
40909
ggml-vulkan-shaders.hpp
File diff suppressed because it is too large
Load diff
598
ggml-vulkan.cpp
598
ggml-vulkan.cpp
File diff suppressed because it is too large
Load diff
|
@ -243,8 +243,6 @@ mulmat_head = """#version 450
|
||||||
#extension GL_EXT_control_flow_attributes : enable
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
#define WARP 32
|
|
||||||
|
|
||||||
#ifndef LOAD_VEC
|
#ifndef LOAD_VEC
|
||||||
#define LOAD_VEC 1
|
#define LOAD_VEC 1
|
||||||
#endif
|
#endif
|
||||||
|
@ -266,7 +264,6 @@ layout (push_constant) uniform parameter
|
||||||
uint stride_b;
|
uint stride_b;
|
||||||
uint stride_d;
|
uint stride_d;
|
||||||
uint k_split;
|
uint k_split;
|
||||||
uint d_offset;
|
|
||||||
|
|
||||||
uint ne02;
|
uint ne02;
|
||||||
uint ne12;
|
uint ne12;
|
||||||
|
@ -286,6 +283,7 @@ layout (constant_id = 5) const uint WN = 32;
|
||||||
layout (constant_id = 6) const uint WMITER = 2;
|
layout (constant_id = 6) const uint WMITER = 2;
|
||||||
layout (constant_id = 7) const uint TM = 4;
|
layout (constant_id = 7) const uint TM = 4;
|
||||||
layout (constant_id = 8) const uint TN = 2;
|
layout (constant_id = 8) const uint TN = 2;
|
||||||
|
layout (constant_id = 9) const uint WARP = 32;
|
||||||
|
|
||||||
shared FLOAT_TYPE buf_a[BM * (BK+1)];
|
shared FLOAT_TYPE buf_a[BM * (BK+1)];
|
||||||
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
||||||
|
@ -299,9 +297,9 @@ void main() {
|
||||||
|
|
||||||
const uint batch_idx_a = i03 * p.ne02 + i02;
|
const uint batch_idx_a = i03 * p.ne02 + i02;
|
||||||
|
|
||||||
const uint blocks_x = (p.M + BM - 1) / BM;
|
const uint blocks_m = (p.M + BM - 1) / BM;
|
||||||
const uint ir = gl_WorkGroupID.x % blocks_x;
|
const uint ir = gl_WorkGroupID.x % blocks_m;
|
||||||
const uint ik = gl_WorkGroupID.x / blocks_x;
|
const uint ik = gl_WorkGroupID.x / blocks_m;
|
||||||
const uint ic = gl_WorkGroupID.y;
|
const uint ic = gl_WorkGroupID.y;
|
||||||
|
|
||||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||||
|
@ -354,7 +352,7 @@ void main() {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
|
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
|
||||||
#else
|
#else
|
||||||
if (ir * BM + loadc + l < p.M && block + loadr < p.K) {
|
if (ir * BM + loadc + l < p.M && block + loadr < end_k) {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
|
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
|
||||||
} else {
|
} else {
|
||||||
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
|
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
|
||||||
|
@ -379,7 +377,7 @@ void main() {
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
|
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
|
||||||
#else
|
#else
|
||||||
if (ic * BN + loadc + l < p.N && block + loadr < p.K) {
|
if (ic * BN + loadc + l < p.N && block + loadr < end_k) {
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
|
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
|
||||||
} else {
|
} else {
|
||||||
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
|
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
|
||||||
|
@ -422,7 +420,7 @@ void main() {
|
||||||
const uint dr = ir * BM + warp_r * WM;
|
const uint dr = ir * BM + warp_r * WM;
|
||||||
const uint dc = ic * BN + warp_c * WN;
|
const uint dc = ic * BN + warp_c * WN;
|
||||||
|
|
||||||
const uint k_split_offset = ik * p.M * p.N;
|
const uint offsets = gl_GlobalInvocationID.z * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||||
|
|
||||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||||
|
@ -432,7 +430,7 @@ void main() {
|
||||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||||
data_d[p.d_offset + gl_GlobalInvocationID.z * p.batch_stride_d + k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -443,7 +441,9 @@ void main() {
|
||||||
|
|
||||||
mulmat_split_k_reduce_src = """#version 450
|
mulmat_split_k_reduce_src = """#version 450
|
||||||
|
|
||||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||||
|
@ -451,7 +451,6 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint ne;
|
uint ne;
|
||||||
uint k_num;
|
uint k_num;
|
||||||
uint d_offset;
|
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
@ -463,11 +462,11 @@ void main() {
|
||||||
|
|
||||||
float result = 0.0f;
|
float result = 0.0f;
|
||||||
|
|
||||||
for (int i = 0; i < p.k_num; i++) {
|
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
||||||
result += data_a[i * p.ne + idx];
|
result += data_a[i * p.ne + idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
data_d[p.d_offset + idx] = result;
|
data_d[idx] = result;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue