vulkan: fix coopmat2 flash attention for non-contiguous inputs (#11281)
Add code similar to mul_mm_cm2 to force alignment of strides, to avoid a performance regression. Add noncontiguous FA tests in test-backend-ops. Fixes #11268.
This commit is contained in:
parent
3edfa7d375
commit
44e18ef939
3 changed files with 82 additions and 12 deletions
|
@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
|
|||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
@ -146,6 +149,23 @@ void main() {
|
|||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
||||
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements
|
||||
uint32_t q_stride = p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
q_stride &= ~7;
|
||||
#if !defined(BLOCK_SIZE)
|
||||
k_stride &= ~7;
|
||||
v_stride &= ~7;
|
||||
#endif
|
||||
}
|
||||
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue