cuda : fix im2col_f32_f16 (ggml/#658)

ggml-ci
This commit is contained in:
leejet 2023-12-19 00:46:10 +08:00 committed by Georgi Gerganov
parent a55876955b
commit f0b2ba2089
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -5273,17 +5273,17 @@ static __global__ void im2col_f32_f16(
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int iiw = ix * s0 + kx * d0 - p0;
const int iih = blockIdx.y * s1 + ky * d1 - p1;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
const int offset_dst =
const int64_t offset_dst =
(blockIdx.y * OW + ix) * CHW +
(blockIdx.z * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
const int offset_src = blockIdx.z * offset_delta;
const int64_t offset_src = blockIdx.z * offset_delta;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
}
}