diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b211b1a8a..db136ae9c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5680,10 +5680,8 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc const int col = threadIdx.x; float sum = 0.0f; - int i = col; - while(i < ncols) { + for (int i = col; i < ncols; i += blockDim.x) { sum += x[row * ncols + i]; - i += blockDim.x; } sum = warp_reduce_sum(sum); @@ -6000,8 +5998,9 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -static __global__ void im2col_f32_f32( - const float * x, float * dst, int batch_offset, +template +static __global__ void im2col_kernel( + const float * x, T * dst, int batch_offset, int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { const int i = threadIdx.x + blockIdx.x * blockDim.x; @@ -6027,44 +6026,10 @@ static __global__ void im2col_f32_f32( (ic * (KW * KH) + ky * KW + kx); if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = (0.0f); + dst[offset_dst] = 0.0f; } else { const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = (x[offset_src + iih * IW + iiw]); - } -} - -static __global__ void im2col_f32_f16( - const float * x, half * dst, int batch_offset, - int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, - int s0, int s1, int p0, int p1, int d0, int d1) { - const int i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= pelements) { - return; - } - - const int ksize = OW * (KH > 1 ? KW : 1); - const int kx = i / ksize; - const int kd = kx * ksize; - const int ky = (i - kd) / OW; - const int ix = i % OW; - - const int oh = blockIdx.y; - const int batch = blockIdx.z / IC; - const int ic = blockIdx.z % IC; - - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = oh * s1 + ky * d1 - p1; - - const int64_t offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = __float2half(0.0f); - } else { - const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); + dst[offset_dst] = x[offset_src + iih * IW + iiw]; } } @@ -7458,24 +7423,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con } } -static void im2col_f32_f32_cuda(const float* x, float* dst, +template +static void im2col_cuda(const float* x, T* dst, int IW, int IH, int OW, int OH, int KW, int KH, int IC, int batch, int batch_offset, int offset_delta, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { const int parallel_elements = OW * KW * KH; const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; dim3 block_nums(num_blocks, OH, batch * IC); - im2col_f32_f32<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); -} - -static void im2col_f32_f16_cuda(const float* x, half* dst, - int IW, int IH, int OW, int OH, int KW, int KH, int IC, - int batch, int batch_offset, int offset_delta, - int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { - const int parallel_elements = OW * KW * KH; - const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OH, batch * IC); - im2col_f32_f16<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); + im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } // buffer pool for cuda @@ -8746,9 +8702,9 @@ static void ggml_cuda_op_im2col( const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 if(dst->type == GGML_TYPE_F16) - im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); else - im2col_f32_f32_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); (void) src0; (void) src0_dd;