fixup! fixup! CUDA: faster softmax via shared memory + fp16 math
parent ae26053d1f
commit e1936bb52f
1 changed file with 2 additions and 2 deletions
@@ -5247,7 +5247,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
     extern __shared__ half data_soft_max_f16[];
     half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
     // (shared memory) buffer to cache values between iterations:
-    half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data);
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);

     half2 max_val = make_half2(-INFINITY, -INFINITY);

@@ -5371,7 +5371,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols;
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;

     float max_val = -INFINITY;

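Why the offset changes: buf_iw holds one partial result per warp during the inter-warp reduction, so it only ever needs WARP_SIZE entries, and the vals cache therefore begins at buf_iw + WARP_SIZE. Offsetting by block_size instead leaves a gap of (block_size - WARP_SIZE) unused elements and no longer matches the shared-memory size reserved at launch. The sketch below is a minimal standalone illustration of that layout, not the llama.cpp kernel itself: the names softmax_demo and block_reduce are made up for the demo, and it always keeps vals in shared memory, whereas the real kernels fall back to writing vals into dst when the row does not fit (the vals_smem flag in the diff above).

// smem_layout_demo.cu -- hypothetical, self-contained sketch of the fixed layout.
// Shared memory per block = [ buf_iw: WARP_SIZE floats | vals: ncols floats ].
#include <cstdio>
#include <cfloat>
#include <vector>

#define WARP_SIZE 32

// Block-wide reduction (max or sum): shuffle within each warp, park one
// partial per warp in buf_iw, then every warp reduces those partials.
// Only one value per warp is stored, so WARP_SIZE slots are enough for
// any block of up to WARP_SIZE*WARP_SIZE = 1024 threads.
// Assumes blockDim.x is a multiple of WARP_SIZE.
__device__ float block_reduce(float v, float * buf_iw, const bool is_max) {
    const int lane    = threadIdx.x % WARP_SIZE;
    const int warp_id = threadIdx.x / WARP_SIZE;
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        const float w = __shfl_xor_sync(0xFFFFFFFF, v, offset);
        v = is_max ? fmaxf(v, w) : v + w;
    }
    if (lane == 0) {
        buf_iw[warp_id] = v; // one partial per warp
    }
    __syncthreads();
    v = lane < blockDim.x/WARP_SIZE ? buf_iw[lane] : (is_max ? -FLT_MAX : 0.0f);
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        const float w = __shfl_xor_sync(0xFFFFFFFF, v, offset);
        v = is_max ? fmaxf(v, w) : v + w;
    }
    __syncthreads(); // buf_iw is reused by the next reduction
    return v;
}

// One block per row; vals caches the row so global memory is read once.
__global__ void softmax_demo(const float * x, float * dst, const int ncols) {
    extern __shared__ float smem[];
    float * buf_iw = smem;               // inter-warp buffer: WARP_SIZE floats
    float * vals   = buf_iw + WARP_SIZE; // the fix: offset by WARP_SIZE, not block_size

    const int row = blockIdx.x;

    float max_val = -FLT_MAX;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float v = x[row*ncols + col];
        vals[col] = v;
        max_val   = fmaxf(max_val, v);
    }
    max_val = block_reduce(max_val, buf_iw, /*is_max=*/true);

    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float e = expf(vals[col] - max_val);
        vals[col] = e;
        sum += e;
    }
    sum = block_reduce(sum, buf_iw, /*is_max=*/false);

    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[row*ncols + col] = vals[col] / sum;
    }
}

int main() {
    const int nrows = 2, ncols = 256, block_size = 128;
    std::vector<float> h(nrows*ncols);
    for (int i = 0; i < nrows*ncols; ++i) h[i] = 0.25f*(i % 9);

    float *dx, *dy;
    cudaMalloc(&dx, h.size()*sizeof(float));
    cudaMalloc(&dy, h.size()*sizeof(float));
    cudaMemcpy(dx, h.data(), h.size()*sizeof(float), cudaMemcpyHostToDevice);

    // Launch-time allocation matches the layout: WARP_SIZE + ncols floats.
    const size_t smem_size = (WARP_SIZE + ncols)*sizeof(float);
    softmax_demo<<<nrows, block_size, smem_size>>>(dx, dy, ncols);

    cudaMemcpy(h.data(), dy, h.size()*sizeof(float), cudaMemcpyDeviceToHost);
    float row0_sum = 0.0f;
    for (int i = 0; i < ncols; ++i) row0_sum += h[i];
    printf("row 0 sums to %f (expect 1.0)\n", row0_sum);

    cudaFree(dx); cudaFree(dy);
    return 0;
}

Compiled with nvcc, row 0 should print a sum of 1.0. Keeping buf_iw fixed at WARP_SIZE elements decouples the cache offset from the launch configuration: the same layout and the allocation formula (WARP_SIZE + ncols elements) hold for any block size up to 1024 threads, which is exactly what the block_size -> WARP_SIZE change in both hunks restores.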