code clean

This commit is contained in:
zhangjidong 2024-01-26 15:38:37 +08:00
parent b08c6b1ad8
commit c29a855453

View file

@ -5680,10 +5680,8 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
const int col = threadIdx.x;
float sum = 0.0f;
int i = col;
while(i < ncols) {
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
i += blockDim.x;
}
sum = warp_reduce_sum(sum);
@ -6000,8 +5998,9 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
static __global__ void im2col_f32_f32(
const float * x, float * dst, int batch_offset,
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int batch_offset,
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
@ -6027,44 +6026,10 @@ static __global__ void im2col_f32_f32(
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = (0.0f);
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = (x[offset_src + iih * IW + iiw]);
}
}
static __global__ void im2col_f32_f16(
const float * x, half * dst, int batch_offset,
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
return;
}
const int ksize = OW * (KH > 1 ? KW : 1);
const int kx = i / ksize;
const int kd = kx * ksize;
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int oh = blockIdx.y;
const int batch = blockIdx.z / IC;
const int ic = blockIdx.z % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
@ -7458,24 +7423,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
}
}
static void im2col_f32_f32_cuda(const float* x, float* dst,
template <typename T>
static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int batch, int batch_offset, int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_f32_f32<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
static void im2col_f32_f16_cuda(const float* x, half* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int batch, int batch_offset, int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
@ -8746,9 +8702,9 @@ static void ggml_cuda_op_im2col(
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16)
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
else
im2col_f32_f32_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
(void) src0;
(void) src0_dd;