From 9fdf8ad826bfc8397f11cb7fe31701dd63990874 Mon Sep 17 00:00:00 2001 From: lihan <1091770049@qq.com> Date: Wed, 11 Dec 2024 09:38:33 +0800 Subject: [PATCH] add constexpr and static assert --- ggml/src/ggml-cuda/concat.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index caf872129..2f42b8a95 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -124,6 +124,8 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb1, uint64_t nb2, uint64_t nb3){ + static_assert(dim >= 0 && dim <= 3); + const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; @@ -134,13 +136,13 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - if /*constexpr*/ (dim == 0) { + if constexpr (dim == 0) { x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); - } else if (dim == 1) { + } else if constexpr (dim == 1) { x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); - } else if (dim == 2) { + } else if constexpr (dim == 2) { x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); - } else if (dim == 3) { + } else if constexpr (dim == 3) { x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); } }