From ed0891f2a48fbaddbb5a185c8e7d8abf06095afc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 27 May 2024 15:08:37 +0300 Subject: [PATCH] cuda : generalize concat kernel ggml-ci --- ggml-cuda/concat.cu | 90 +++++++++++++++++++++++++++++++++++--- tests/test-backend-ops.cpp | 10 ++--- 2 files changed, 89 insertions(+), 11 deletions(-) diff --git a/ggml-cuda/concat.cu b/ggml-cuda/concat.cu index 2941d2f17..caa9c1e04 100644 --- a/ggml-cuda/concat.cu +++ b/ggml-cuda/concat.cu @@ -1,15 +1,68 @@ #include "concat.cuh" -static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) { +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; } - // operation + int offset_dst = nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; + + if (nidx < ne00) { // src0 + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * ne00 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + (nidx - ne00) + + blockIdx.y * (ne0 - ne00) + + blockIdx.z * (ne0 - ne00) * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (blockIdx.y < ne01) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + (blockIdx.y - ne01) * ne0 + + blockIdx.z * ne0 * (gridDim.y - ne01); + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + if (blockIdx.z < ne02) { // src0 int offset_src = nidx + @@ -25,25 +78,50 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst, } } -static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) { +static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2); - concat_f32<<>>(x, y, dst, ne0, ne02); + if (dim == 0) { + concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + return; + } + if (dim == 1) { + concat_f32_dim1<<>>(x, y, dst, ne0, ne01); + return; + } + concat_f32_dim2<<>>(x, y, dst, ne0, ne02); } void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + const int32_t dim = ((int32_t *) dst->op_params)[0]; + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream); + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3a8284a2a..876f7329a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1261,21 +1261,21 @@ struct test_concat : public test_case { const ggml_type type; const std::array ne_a; const int dim; - const int64_t b_ned; + const int64_t ne_b_d; std::string vars() override { - return VARS_TO_STR4(type, ne_a, dim, b_ned); + return VARS_TO_STR4(type, ne_a, dim, ne_b_d); } test_concat(ggml_type type = GGML_TYPE_F32, std::array ne_a = {10, 10, 10, 10}, int dim = 2, - int64_t b_ned = 10) - : type(type), ne_a(ne_a), dim(dim), b_ned(b_ned) {} + int64_t ne_b_d = 10) + : type(type), ne_a(ne_a), dim(dim), ne_b_d(ne_b_d) {} ggml_tensor * build_graph(ggml_context * ctx) override { auto ne_b = ne_a; - ne_b[dim] = b_ned; + ne_b[dim] = ne_b_d; ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); ggml_tensor * out = ggml_concat(ctx, a, b, dim);