From 91c10ef225a56505bf17885d301003be349be3e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Sun, 28 Apr 2024 17:07:41 -0400 Subject: [PATCH] Fix some more int overflow in softmax. --- ggml-cuda/softmax.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 9bda18e58..fa8f987cf 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols; + float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; float max_val = -INFINITY; @@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f break; } - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; + const int64_t ix = (int64_t)rowx*ncols + col; + const int64_t iy = (int64_t)rowy*ncols + col; const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); @@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f return; } - const int idst = rowx*ncols + col; + const int64_t idst = (int64_t)rowx*ncols + col; dst[idst] = vals[col] * inv_sum; } }