vulkan: fix coopmat2 validation failures (#11284)
mul mat and flash attention shaders were loading f32 types directly into A/B matrices, which happens to work but is technically invalid usage. For FA, we can load it as an Accumulator matrix and convert and this is not in the inner loop and is cheap enough. For mul mat, it's more efficient to do this conversion in a separate pass and have the input(s) be f16. coopmat2 requires SPIR-V 1.6 (related using to LocalSizeId). LocalSizeId requires maintenance4 be enabled, and SPIR-V 1.6 requires Vulkan 1.3.
This commit is contained in:
parent
9f7add1cde
commit
aea8ddd516
4 changed files with 53 additions and 56 deletions
|
@ -166,7 +166,7 @@ void main() {
|
|||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
|
|
|
@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
|||
|
||||
#if QUANT_K > 1
|
||||
#define DECODEFUNCA , dequantFuncA
|
||||
#define MAT_A_TYPE float16_t
|
||||
|
||||
#include "dequant_funcs_cm2.comp"
|
||||
|
||||
#else
|
||||
#define DECODEFUNCA
|
||||
#define MAT_A_TYPE A_TYPE
|
||||
#endif
|
||||
|
||||
#define MAT_B_TYPE B_TYPE
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
|
@ -236,16 +232,13 @@ void main() {
|
|||
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
} else
|
||||
#endif // !defined(MUL_MAT_ID)
|
||||
|
@ -261,10 +254,8 @@ void main() {
|
|||
[[dont_unroll]]
|
||||
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
|
||||
|
||||
coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
// Clamping is expensive, so detect different code paths for each combination
|
||||
// of A and B needing clamping.
|
||||
|
@ -281,16 +272,12 @@ void main() {
|
|||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
|
||||
#endif
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (unclampedA && !unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (!unclampedA && unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
|
@ -298,16 +285,12 @@ void main() {
|
|||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
|
||||
#endif
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (!unclampedA && !unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
|
||||
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue