fixup! fixup! CUDA: faster softmax via shared memory + fp16 math

JohannesGaessler 2024-01-03 16:56:27 +01:00
parent ae26053d1f
commit e1936bb52f


@@ -5247,7 +5247,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
     extern __shared__ half data_soft_max_f16[];
     half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
     // (shared memory) buffer to cache values between iterations:
-    half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data);
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);

     half2 max_val = make_half2(-INFINITY, -INFINITY);
@@ -5371,7 +5371,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols;
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;

     float max_val = -INFINITY;
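Note on the fix: buf_iw is the buffer for inter-warp reduction partials, and a block has at most WARP_SIZE warps (block_size <= WARP_SIZE*WARP_SIZE), so WARP_SIZE slots suffice for it regardless of block_size. Both hunks move the start of the vals cache from buf_iw + block_size to buf_iw + WARP_SIZE, which matches a shared-memory reservation of WARP_SIZE elements for buf_iw. Below is a minimal sketch of the layout this implies; the kernel name, body, and launch line are illustrative assumptions, not the upstream soft_max kernels:

    #include <cuda_runtime.h>

    #define WARP_SIZE 32

    // Hypothetical sketch kernel; shows only the shared-memory layout.
    static __global__ void smem_layout_sketch(float * dst, const int ncols) {
        extern __shared__ float smem[];

        // One reduction slot per warp; a block never has more than WARP_SIZE
        // warps, so the inter-warp buffer needs at most WARP_SIZE elements.
        float * buf_iw = smem;

        // The per-row value cache starts right after those WARP_SIZE slots.
        // Offsetting by block_size instead would disagree with the launch
        // below whenever block_size != WARP_SIZE.
        float * vals = buf_iw + WARP_SIZE;

        for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
            vals[col] = dst[col]; // stand-in for the cached intermediates
        }
    }

    // Matching host-side launch: WARP_SIZE reduction slots plus the ncols cache.
    // smem_layout_sketch<<<nrows, block_size, (WARP_SIZE + ncols)*sizeof(float)>>>(dst, ncols);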