cuda : "constexpr dim3" -> "const dim3"

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-22 20:31:23 +03:00
parent 5408d55506
commit c70bfd7bcb
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@@ -652,7 +652,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
}
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
+const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
const int shmem = 0;
@@ -680,9 +680,9 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
return;
}
-constexpr dim3 block_dim_combine(D, 1, 1);
-const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-const int shmem_combine = 0;
+const dim3 block_dim_combine(D, 1, 1);
+const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
@@ -703,7 +703,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
}
constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
-constexpr dim3 block_dim(WARP_SIZE, nwarps, 1);
+const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;
@@ -731,9 +731,9 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
return;
}
-constexpr dim3 block_dim_combine(D, 1, 1);
-const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-const int shmem_combine = 0;
+const dim3 block_dim_combine(D, 1, 1);
+const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>