diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 52c5599a4..d2dbf824e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -473,7 +473,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -template +template static __global__ void norm_f32(const float * x, float * dst, const int ncols) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; @@ -482,7 +482,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { float2 mean_var = make_float2(0.f, 0.f); - for (int col = tid; col < ncols; col += BLOCK_SIZE) { + for (int col = tid; col < ncols; col += block_size) { const float xi = x[row*ncols + col]; mean_var.x += xi; mean_var.y += xi * xi; @@ -490,7 +490,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { // sum up partial sums mean_var = warp_reduce_sum(mean_var); - if (BLOCK_SIZE > WARP_SIZE) { + if (block_size > WARP_SIZE) { __shared__ float2 s_sum[32]; int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; @@ -506,7 +506,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { const float var = mean_var.y / ncols - mean * mean; const float inv_std = rsqrtf(var + eps); - for (int col = tid; col < ncols; col += BLOCK_SIZE) { + for (int col = tid; col < ncols; col += block_size) { dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; } } @@ -519,21 +519,21 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) { return x; } -template +template static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; float tmp = 0.0f; // partial sum for thread in warp - for (int col = tid; col < ncols; col += BLOCK_SIZE) { + for (int col = tid; col < ncols; col += block_size) { const float xi = x[row*ncols + col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp); - if (BLOCK_SIZE > WARP_SIZE) { + if (block_size > WARP_SIZE) { __shared__ float s_sum[32]; int warp_id = threadIdx.x / WARP_SIZE; int lane_id = threadIdx.x % WARP_SIZE; @@ -548,7 +548,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const float mean = tmp / ncols; const float scale = rsqrtf(mean + eps); - for (int col = tid; col < ncols; col += BLOCK_SIZE) { + for (int col = tid; col < ncols; col += block_size) { dst[row*ncols + col] = scale * x[row*ncols + col]; } }