Use a lambda to avoid code duplication
Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
e4189e3188
commit
f8a5b04441
1 changed files with 16 additions and 20 deletions
|
@ -188,37 +188,33 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
}
|
||||
} else {
|
||||
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
auto launch_kernel = [&](auto dim) {
|
||||
concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
|
||||
(const char *) src0->data, (const char *) src1->data, (char *) dst->data,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
};
|
||||
switch (dim) {
|
||||
case 0:
|
||||
concat_f32_non_cont<0><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
|
||||
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
|
||||
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
|
||||
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
launch_kernel(std::integral_constant<int, 0>{});
|
||||
break;
|
||||
case 1:
|
||||
concat_f32_non_cont<1><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
|
||||
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
|
||||
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
|
||||
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
launch_kernel(std::integral_constant<int, 1>{});
|
||||
break;
|
||||
case 2:
|
||||
concat_f32_non_cont<2><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
|
||||
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
|
||||
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
|
||||
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
launch_kernel(std::integral_constant<int, 2>{});
|
||||
break;
|
||||
case 3:
|
||||
concat_f32_non_cont<3><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
|
||||
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
|
||||
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
|
||||
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
launch_kernel(std::integral_constant<int, 3>{});
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Invalid dim: %d", dim);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue