Fix CUDA softmax by subtracting max value before exp

lijiahao 2023-08-19 11:55:01 +08:00
parent 1f0bccb279
commit 12e4284c31


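In brief, the reason for the change: expf overflows a 32-bit float once its argument exceeds roughly 88.7, so a single large logit turns the row sum into inf and every softmax output into NaN. Subtracting the row maximum before exponentiating keeps every argument at or below zero and leaves the result unchanged, because the common factor exp(-max) cancels between numerator and denominator. A minimal host-side sketch of the effect, not part of this commit and with made-up values:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        // two hypothetical logits, both large enough to overflow expf in float32
        const float x0 = 100.0f, x1 = 98.0f;

        // naive softmax of x0: inf / (inf + inf) evaluates to NaN
        const float naive = expf(x0) / (expf(x0) + expf(x1));

        // max-shifted softmax: expf(0) / (expf(0) + expf(-2)) ~= 0.8808
        const float m = fmaxf(x0, x1);
        const float shifted = expf(x0 - m) / (expf(x0 - m) + expf(x1 - m));

        printf("naive: %f  shifted: %f\n", naive, shifted);
        return 0;
    }
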
@@ -3955,24 +3955,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-// values are also not normalized to the maximum value by subtracting it in the exponential function
-// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int block_size = blockDim.x;
     const int tid = threadIdx.x;

-    float tmp = 0.0;
-
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
-
+    float max_val = -INFINITY;
+
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        const float val = expf(x[i]);
+        max_val = max(max_val, x[i]);
+    }
+
+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
+
+    float tmp = 0.f;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        const float val = expf(x[i] - max_val);
         tmp += val;
         dst[i] = val;
     }

@@ -3983,15 +3988,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncols
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }

-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
-
+    const float inv_tmp = 1.f / tmp;
+
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        dst[i] /= tmp;
+        dst[i] *= inv_tmp;
     }
 }
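
After the change, the 32 lanes handling a row make three strided passes: each lane folds its columns into a partial maximum, the partial maxima are combined with the same __shfl_xor_sync butterfly already used for the sum, the shifted exponentials are accumulated and written out, and the final pass multiplies by the precomputed reciprocal 1.f / tmp instead of dividing every element. A scalar reference with the same three-pass structure can be handy for checking a copied-back row; the sketch below is a hypothetical helper, not code from the repository:

    #include <math.h>

    // reference softmax over one row of ncols floats, mirroring the kernel:
    // pass 1 finds the maximum, pass 2 accumulates shifted exponentials,
    // pass 3 scales by the reciprocal of the sum
    static void soft_max_f32_ref(const float * x, float * dst, const int ncols) {
        float max_val = -INFINITY;
        for (int col = 0; col < ncols; ++col) {
            max_val = fmaxf(max_val, x[col]);
        }

        float tmp = 0.f;
        for (int col = 0; col < ncols; ++col) {
            const float val = expf(x[col] - max_val);
            tmp += val;
            dst[col] = val;
        }

        const float inv_tmp = 1.f / tmp;
        for (int col = 0; col < ncols; ++col) {
            dst[col] *= inv_tmp;
        }
    }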