diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 2077da53d..aaaea2f07 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -652,7 +652,7 @@ template void launch_fattn_vec_f16( } constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; - constexpr dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 block_dim(WARP_SIZE, nwarps, 1); const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -680,9 +680,9 @@ template void launch_fattn_vec_f16( return; } - constexpr dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; flash_attn_combine_results <<>> @@ -703,7 +703,7 @@ template ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const int shmem = 0; @@ -731,9 +731,9 @@ template ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; flash_attn_combine_results <<>>