diff --git a/examples/llava/MobileVLM-README.md b/examples/llava/MobileVLM-README.md index c6258eba6..9eba791da 100644 --- a/examples/llava/MobileVLM-README.md +++ b/examples/llava/MobileVLM-README.md @@ -111,17 +111,71 @@ llama_print_timings: eval time = 1279.03 ms / 18 runs ( 71.06 m llama_print_timings: total time = 34570.79 ms ``` +## Orin compile and run +### compile +```sh +make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=sm_87 LLAMA_CUDA_F16=1 -j 32 +``` + +### run on Orin +### case 1 +**input** +```sh +./llava-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + --image /data/local/tmp/demo.jpeg \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" \ + --n-gpu-layers 999 +``` +**output** +```sh + +encode_image_with_clip: image encoded in 296.62 ms by CLIP ( 2.06 ms per image patch) + + Susan Wise Bauer + +llama_print_timings: load time = 1067.64 ms +llama_print_timings: sample time = 1.53 ms / 6 runs ( 0.25 ms per token, 3934.43 tokens per second) +llama_print_timings: prompt eval time = 306.84 ms / 246 tokens ( 1.25 ms per token, 801.72 tokens per second) +llama_print_timings: eval time = 91.50 ms / 6 runs ( 15.25 ms per token, 65.58 tokens per second) +llama_print_timings: total time = 1352.63 ms / 252 tokens +``` + +### case 2 +**input** +```sh +./llava-cli \ + -m /data/local/tmp/ggml-model-q4_k.gguf \ + --mmproj /data/local/tmp/mmproj-model-f16.gguf \ + -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" \ + --n-gpu-layers 999 + +``` +**output** +```sh +encode_image_with_clip: image encoded in 302.15 ms by CLIP ( 2.10 ms per image patch) + + The image features a cat lying in the grass. + +llama_print_timings: load time = 1057.07 ms +llama_print_timings: sample time = 3.27 ms / 11 runs ( 0.30 ms per token, 3360.83 tokens per second) +llama_print_timings: prompt eval time = 213.60 ms / 232 tokens ( 0.92 ms per token, 1086.14 tokens per second) +llama_print_timings: eval time = 166.65 ms / 11 runs ( 15.15 ms per token, 66.01 tokens per second) +llama_print_timings: total time = 1365.47 ms / 243 tokens +``` + ## Minor shortcomings The `n_patch` of output in `ldp` is 1/4 of the input. In order to implement quickly, we uniformly modified `clip_n_patches` function to a quarter. when counting the time consumption, the calculated time will be 4 times bigger than the real cost. ## TODO -- [ ] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid` +- [x] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid` - [ ] Optimize LDP projector performance - Optimize the structure definition to avoid unnecessary memory rearrangements, to reduce the use of `ggml_permute_cpy`; - Optimize operator implementation (ARM CPU/NVIDIA GPU): such as depthwise conv, hardswish, hardsigmoid, etc. -- [ ] run MobileVLM on `Jetson Orin` +- [x] run MobileVLM on `Jetson Orin` - [ ] Support more model variants, such as `MobileVLM-3B`. diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7f460449e..b211b1a8a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -512,6 +512,8 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_HARDSIGMOID_BLOCK_SIZE 256 +#define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 @@ -811,6 +813,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) { dst[i] = fmaxf(x[i], 0); } +static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); +} + +static __global__ void hardswish_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); +} + static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -5656,12 +5676,14 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols, } static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.y; + const int row = blockIdx.x; const int col = threadIdx.x; float sum = 0.0f; - for (int i = col; i < ncols; i += blockDim.x) { + int i = col; + while(i < ncols) { sum += x[row * ncols + i]; + i += blockDim.x; } sum = warp_reduce_sum(sum); @@ -5978,9 +6000,9 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -static __global__ void im2col_f32_f16( - const float * x, half * dst, - int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW, +static __global__ void im2col_f32_f32( + const float * x, float * dst, int batch_offset, + int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { const int i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= pelements) { @@ -5993,17 +6015,55 @@ static __global__ void im2col_f32_f16( const int ky = (i - kd) / OW; const int ix = i % OW; + const int oh = blockIdx.y; + const int batch = blockIdx.z / IC; + const int ic = blockIdx.z % IC; + const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = blockIdx.y * s1 + ky * d1 - p1; + const int64_t iih = oh * s1 + ky * d1 - p1; const int64_t offset_dst = - (blockIdx.y * OW + ix) * CHW + - (blockIdx.z * (KW * KH) + ky * KW + kx); + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = (0.0f); + } else { + const int64_t offset_src = ic * offset_delta + batch * batch_offset; + dst[offset_dst] = (x[offset_src + iih * IW + iiw]); + } +} + +static __global__ void im2col_f32_f16( + const float * x, half * dst, int batch_offset, + int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, + int s0, int s1, int p0, int p1, int d0, int d1) { + const int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= pelements) { + return; + } + + const int ksize = OW * (KH > 1 ? KW : 1); + const int kx = i / ksize; + const int kd = kx * ksize; + const int ky = (i - kd) / OW; + const int ix = i % OW; + + const int oh = blockIdx.y; + const int batch = blockIdx.z / IC; + const int ic = blockIdx.z % IC; + + const int64_t iiw = ix * s0 + kx * d0 - p0; + const int64_t iih = oh * s1 + ky * d1 - p1; + + const int64_t offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst[offset_dst] = __float2half(0.0f); } else { - const int64_t offset_src = blockIdx.z * offset_delta; + const int64_t offset_src = ic * offset_delta + batch * batch_offset; dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); } } @@ -6221,6 +6281,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ relu_f32<<>>(x, dst, k); } +static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE; + hardsigmoid_f32<<>>(x, dst, k); +} + +static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE; + hardswish_f32<<>>(x, dst, k); +} + static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) { const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; leaky_relu_f32<<>>(x, dst, k, negative_slope); @@ -7276,7 +7346,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); } @@ -7388,14 +7458,24 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con } } -static void im2col_f32_f16_cuda(const float* x, half* dst, +static void im2col_f32_f32_cuda(const float* x, float* dst, int IW, int IH, int OW, int OH, int KW, int KH, int IC, - int offset_delta, + int batch, int batch_offset, int offset_delta, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { const int parallel_elements = OW * KW * KH; const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OH, IC); - im2col_f32_f16<<>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); + dim3 block_nums(num_blocks, OH, batch * IC); + im2col_f32_f32<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); +} + +static void im2col_f32_f16_cuda(const float* x, half* dst, + int IW, int IH, int OW, int OH, int KW, int KH, int IC, + int batch, int batch_offset, int offset_delta, + int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { + const int parallel_elements = OW * KW * KH; + const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OH, batch * IC); + im2col_f32_f16<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } // buffer pool for cuda @@ -7980,6 +8060,34 @@ static void ggml_cuda_op_relu( (void) src1_dd; } +static void ggml_cuda_op_hardsigmoid( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +static void ggml_cuda_op_hardswish( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_leaky_relu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { @@ -8612,7 +8720,7 @@ static void ggml_cuda_op_im2col( GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; @@ -8634,8 +8742,13 @@ static void ggml_cuda_op_im2col( const int64_t OW = dst->ne[1]; const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t batch = src1->ne[3]; + const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 - im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + if(dst->type == GGML_TYPE_F16) + im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + else + im2col_f32_f32_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); (void) src0; (void) src0_dd; @@ -9231,6 +9344,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); } +static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid); +} + +static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish); +} static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu); } @@ -10109,6 +10229,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_UNARY_OP_RELU: func = ggml_cuda_relu; break; + case GGML_UNARY_OP_HARDSIGMOID: + func = ggml_cuda_hardsigmoid; + break; + case GGML_UNARY_OP_HARDSWISH: + func = ggml_cuda_hardswish; + break; default: return false; } @@ -10917,6 +11043,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: return true; diff --git a/ggml.c b/ggml.c index ca98fde8a..1c74d80e3 100644 --- a/ggml.c +++ b/ggml.c @@ -5296,7 +5296,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d( int s0, int p0, int d0) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -5374,16 +5374,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d( int p1, int d0, int d1) { + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), - s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW] - - struct ggml_tensor * result = - ggml_mul_mat(ctx, - ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC,1, KH, KW] => [1, OC, 1, KH * KW] - ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] + struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b); result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] return result; @@ -5404,7 +5403,8 @@ struct ggml_tensor * ggml_im2col( int p1, int d0, int d1, - bool is_2D) { + bool is_2D, + enum ggml_type dst_type) { if(is_2D) { GGML_ASSERT(a->ne[2] == b->ne[2]); @@ -5428,7 +5428,7 @@ struct ggml_tensor * ggml_im2col( is_2D ? b->ne[3] : 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; ggml_set_op_params(result, params, sizeof(params)); @@ -5453,7 +5453,7 @@ struct ggml_tensor * ggml_conv_2d( int p1, int d0, int d1) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -5579,12 +5579,28 @@ struct ggml_tensor * ggml_pool_2d( is_node = true; } + struct ggml_tensor * result; +#if defined(GGML_USE_CUBLAS) + if(!(op == GGML_OP_POOL_AVG)) { + GGML_ASSERT(false); + } + + const int64_t ne[4] = {k0, k1, 1, a->ne[2]}; + struct ggml_tensor * b = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); + struct ggml_tensor * im2col = ggml_im2col(ctx, b, new_a, + s0, s1, p0, p1, 1, 1, true, GGML_TYPE_F32); // [N * IC, OH, OW, KH * KW] + + result = ggml_sum_rows(ctx, im2col); + result = ggml_scale(ctx, result, 1. / (k0 * k1)); + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[2], a->ne[3]); +#else const int64_t ne[3] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), a->ne[2], }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); @@ -5592,7 +5608,7 @@ struct ggml_tensor * ggml_pool_2d( result->op = GGML_OP_POOL_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - +#endif return result; } @@ -12416,6 +12432,92 @@ static void ggml_compute_forward_conv_transpose_1d( } } +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_im2col_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + + // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -12506,14 +12608,14 @@ static void ggml_compute_forward_im2col( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - switch (src0->type) { + switch (dst->type) { case GGML_TYPE_F16: { ggml_compute_forward_im2col_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - GGML_ASSERT(false); + ggml_compute_forward_im2col_f32(params, src0, src1, dst); } break; default: { diff --git a/ggml.h b/ggml.h index 1c4976271..0f541b2e7 100644 --- a/ggml.h +++ b/ggml.h @@ -1493,7 +1493,8 @@ extern "C" { int p1, int d0, int d1, - bool is_2D); + bool is_2D, + enum ggml_type dst_type); GGML_API struct ggml_tensor * ggml_conv_depthwise_2d( struct ggml_context * ctx, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d..ac1ae8ad2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1174,7 +1174,7 @@ struct test_im2col : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); - ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D); + ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F16); return out; } };