Use a lambda to avoid code duplication

Co-authored-by: Diego Devesa <slarengh@gmail.com>
a3sh authored 2024-12-11 09:14:30 +08:00 (committed by GitHub)
parent e4189e3188
commit f8a5b04441

@@ -188,37 +188,33 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 } else {
 dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+auto launch_kernel = [&](auto dim) {
+concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+(const char *) src0->data, (const char *) src1->data, (char *) dst->data,
+src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+};
 switch (dim) {
 case 0:
-concat_f32_non_cont<0><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
-(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
-src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
-src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
-dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+launch_kernel(std::integral_constant<int, 0>{});
 break;
 case 1:
-concat_f32_non_cont<1><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
-(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
-src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
-src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
-dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+launch_kernel(std::integral_constant<int, 1>{});
 break;
 case 2:
-concat_f32_non_cont<2><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
-(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
-src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
-src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
-dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+launch_kernel(std::integral_constant<int, 2>{});
 break;
 case 3:
-concat_f32_non_cont<3><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
-(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
-src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
-src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
-dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+launch_kernel(std::integral_constant<int, 3>{});
 break;
 default:
 GGML_ABORT("Invalid dim: %d", dim);
 break;
 }
 }
 }
 }
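
For reference, here is a minimal standalone sketch (not the ggml code) of the dispatch pattern this commit introduces: a generic lambda spells out the shared call once, and std::integral_constant carries the runtime dim value into the compile-time template parameter each case needs. The templated function work<Dim> below is a hypothetical stand-in for concat_f32_non_cont; the sketch compiles as C++14 or later.

    // Hypothetical stand-in for a templated kernel such as concat_f32_non_cont<Dim>.
    #include <cstdio>
    #include <type_traits>

    template <int Dim>
    static void work(int value) {
        std::printf("dim=%d value=%d\n", Dim, value);
    }

    static void dispatch(int dim, int value) {
        // The generic lambda spells out the full call exactly once; `d` carries the
        // dimension as a compile-time constant (std::integral_constant<int, N>).
        auto launch = [&](auto d) {
            work<decltype(d)::value>(value);
        };
        switch (dim) {
            case 0: launch(std::integral_constant<int, 0>{}); break;
            case 1: launch(std::integral_constant<int, 1>{}); break;
            case 2: launch(std::integral_constant<int, 2>{}); break;
            case 3: launch(std::integral_constant<int, 3>{}); break;
            default: std::fprintf(stderr, "invalid dim: %d\n", dim); break;
        }
    }

    int main() {
        dispatch(2, 42); // prints "dim=2 value=42"
        return 0;
    }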