From fde13b3bb9f569f07fd8af74696ee48a43d05131 Mon Sep 17 00:00:00 2001 From: John Balis Date: Tue, 2 Jul 2024 11:09:52 -0500 Subject: [PATCH] feat: cuda implementation for `ggml_conv_transpose_1d` (ggml/854) * conv transpose 1d passing test for 1d input and kernel * working for different input and output channel counts, added test for variable stride * initial draft appears to work with stride other than 1 * working with all old and new conv1d tests * added a test for large tensors * removed use cuda hardcoding * restored test-conv-transpose.c * removed unused arugments, and fixed bug where test failure would cause subsequent tests to fail * fixed accumulator bug * added test to test-backend-ops * fixed mistake * addressed review * fixed includes * removed blank lines * style and warning fixes * return failure when test fails * fix supports_op --------- Co-authored-by: slaren --- ggml/src/ggml-cuda.cu | 13 ++++ ggml/src/ggml-cuda/conv-transpose-1d.cu | 87 ++++++++++++++++++++++++ ggml/src/ggml-cuda/conv-transpose-1d.cuh | 5 ++ tests/test-backend-ops.cpp | 42 +++++++++++- 4 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-cuda/conv-transpose-1d.cu create mode 100644 ggml/src/ggml-cuda/conv-transpose-1d.cuh diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 1c9ccc8a1..ed784ea1c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -29,6 +29,7 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" +#include "ggml-cuda/conv-transpose-1d.cuh" #include #include @@ -2261,6 +2262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_cuda_op_conv_transpose_1d(ctx,dst); + break; case GGML_OP_POOL_2D: ggml_cuda_op_pool2d(ctx, dst); break; @@ -2804,6 +2808,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons ggml_type src0_type = op->src[0]->type; return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + return false; + } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu new file mode 100644 index 000000000..b1e94d6f7 --- /dev/null +++ b/ggml/src/ggml-cuda/conv-transpose-1d.cu @@ -0,0 +1,87 @@ +#include "conv-transpose-1d.cuh" + +static __global__ void conv_transpose_1d_kernel( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst) { + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + if (global_index >= output_size) { + return; + } + + int out_index = global_index / dst_ne0; + + float accumulator = 0; + + for (int c = 0; c < src0_ne2; c++) { + int idx = global_index % dst_ne0; + + int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); + int input_offset = src1_ne0 * c; + + for (int i = 0; i < src1_ne0; i++) { + if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) { + continue; + } + int weight_idx = idx - i*s0; + + float kernel_weight = src0[kernel_offset + weight_idx]; + float input_value = src1[input_offset+i]; + + accumulator += kernel_weight * input_value; + } + } + dst[global_index] = accumulator; +} + +static void conv_transpose_1d_f32_f32_cuda( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst, + cudaStream_t stream) { + + const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE; + conv_transpose_1d_kernel<<>>( + s0,p0,d0,output_size, + src0_ne0, src0_ne1, src0_ne2, src0_ne3, + src1_ne0, src1_ne1, src1_ne2, src1_ne3, + dst_ne0, dst_ne1, dst_ne2, dst_ne3, + src0,src1, dst); +} + +void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *)dst->op_params; + + const int s0 = opts[0]; + const int p0 = 0;//opts[3]; + const int d0 = 1;//opts[4]; + + const int64_t kernel_size = ggml_nelements(src0); + const int64_t input_size = ggml_nelements(src1); + const int64_t output_size = ggml_nelements(dst); + + conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); +} diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cuh b/ggml/src/ggml-cuda/conv-transpose-1d.cuh new file mode 100644 index 000000000..6c2cf666b --- /dev/null +++ b/ggml/src/ggml-cuda/conv-transpose-1d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2bb71ac03..83d4a0eb1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1266,6 +1266,36 @@ struct test_pool2d : public test_case { } }; +// GGML_OP_CONV_TRANSPOSE_1D +struct test_conv_transpose_1d : public test_case { + + const std::array ne_input; + const std::array ne_kernel; + + // stride + const int s0; + // padding + const int p0; + // dilation + const int d0; + + std::string vars() override { + return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0); + } + + test_conv_transpose_1d(std::array ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1] + std::array ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1] + int s0 = 1, int p0 = 0, int d0 = 1) + : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); + ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data()); + ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0); + return out; + } +}; + // GGML_OP_IM2COL struct test_im2col : public test_case { const ggml_type type_input; @@ -1279,7 +1309,7 @@ struct test_im2col : public test_case { // padding const int p0; const int p1; - // dilatation + // dilation const int d0; const int d1; // mode @@ -2098,6 +2128,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + test_cases.emplace_back(new test_conv_transpose_1d()); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); + test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); + + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));