cuda : restore correct im2col

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-12-13 18:40:44 +02:00
parent 694449f0e8
commit a13ba03849
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -5246,19 +5246,30 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int ofs0, int ofs1, int IW, int IH, int CHW,
int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
const int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
return;
}
const int ksize = OW * (KH > 1 ? KW : 1);
const int kx = i / ksize;
const int kd = kx * ksize;
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int iiw = ix * s0 + kx * d0 - p0;
const int iih = blockIdx.y * s1 + ky * d1 - p1;
const int offset_dst =
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
(blockIdx.y * OW + ix) * CHW +
(blockIdx.z * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
const int offset_src = blockIdx.z * offset_delta;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
}
}
@ -6491,13 +6502,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
static void im2col_f32_f16_cuda(const float * x, half * dst,
int OH, int IW, int IH, int OW, int IC,
int KH, int KW, int N, int ofs0, int ofs1,
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
dim3 block_nums(IC, OH, OW);
dim3 block_dims(N, KH, KW);
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
static void im2col_f32_f16_cuda(const float* x, half* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, IC);
im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
@ -7577,7 +7589,6 @@ inline void ggml_cuda_op_im2col(
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
@ -7588,17 +7599,15 @@ inline void ggml_cuda_op_im2col(
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
OH, IW, IH, OW, IC, KH, KW, N,
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
(void) src0;
(void) src0_dd;
}
inline void ggml_cuda_op_sum_rows(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {