CUDA: faster softmax via shared memory + fp16 math

This commit is contained in:
JohannesGaessler 2024-01-02 16:50:43 +01:00
parent 540938f890
commit 64c46fc6f5
2 changed files with 288 additions and 27 deletions

View file

@ -116,6 +116,7 @@
#include "ggml.h"
#include "ggml-backend-impl.h"
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_OFFSET_AMD 1000000
@ -585,6 +586,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
return a;
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
(void) a;
bad_arch();
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@ -593,6 +607,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
(void) x;
bad_arch();
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
@ -5201,75 +5228,227 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
template <int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data/2, WARP_SIZE);
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = blockDim.x;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
extern __shared__ half2 data_soft_max_f16[];
half2 * vals = data_soft_max_f16 + 0; // shared memory buffer to cache values between iterations
half * buf_iw = (half *) (data_soft_max_f16 + ncols_smem); // shared memory buffer for inter-warp communication
float max_val = -INFINITY;
half2 max_val = make_half2(-INFINITY, -INFINITY);
for (int col = tid; col < ncols; col += block_size) {
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
#pragma unroll
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
const int col_smem = col0 + tid;
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
const int ix = rowx*ncols_data + col_data;
const int iy = rowy*ncols_data + col_data;
half2 val;
if (need_check && col_data + 0 >= ncols_data) {
val.x = -INFINITY;
} else {
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
}
if (need_check && col_data + WARP_SIZE >= ncols_data) {
val.y = -INFINITY;
} else {
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
}
vals[col_smem] = val;
max_val = __hmax2(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = max_val;
buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
}
__syncthreads();
max_val = buf[lane_id];
max_val = __half2half2(buf_iw[lane_id]);
max_val = warp_reduce_max(max_val);
} else {
max_val = __half2half2(__hmax(max_val.x, max_val.y));
}
float tmp = 0.f;
half2 tmp = make_half2(0.0f, 0.0f); // partial sums
#pragma unroll
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
const int col_smem = col0 + tid;
if (ncols_template == 0 && col_smem >= ncols_smem) {
break;
}
const half2 val = h2exp(vals[col_smem] - max_val);
for (int col = tid; col < ncols; col += block_size) {
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
dst[ix] = val;
vals[col_smem] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = 0.f;
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = tmp;
buf_iw[warp_id] = tmp.x + tmp.y;
}
__syncthreads();
tmp = buf[lane_id];
tmp = __half2half2(buf_iw[lane_id]);
tmp = warp_reduce_sum(tmp);
} else {
tmp = __half2half2(tmp.x + tmp.y);
}
const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
const int col_smem = col0 + tid;
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
const int idst = rowx*ncols_data + col_data;
const half2 result = vals[col_smem] * inv_sum;
if (need_check && col_data + 0 >= ncols_data) {
return;
}
dst[idst] = result.x;
if (need_check && col_data + WARP_SIZE >= ncols_data) {
return;
}
dst[idst + WARP_SIZE] = result.y;
}
#else
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
template <int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
extern __shared__ float data_soft_max_f32[];
float * vals = data_soft_max_f32 + 0; // shared memory buffer to cache values between iterations
float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication
float max_val = -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = expf(vals[col] - max_val);
tmp += val;
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float inv_tmp = 1.f / tmp;
const float inv_sum = 1.0f / tmp;
for (int col = tid; col < ncols; col += block_size) {
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
return;
}
const int idst = rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
}
}
@ -6543,12 +6722,80 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
switch (ncols_x) {
case 32:
soft_max_f16<32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 64:
soft_max_f16<64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 128:
soft_max_f16<128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 256:
soft_max_f16<256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 512:
soft_max_f16<512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 1024:
soft_max_f16<1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 2048:
soft_max_f16<2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 4096:
soft_max_f16<4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
default:
soft_max_f16<0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
}
}
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
switch (ncols_x) {
case 32:
soft_max_f32<32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 64:
soft_max_f32<64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 128:
soft_max_f32<128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 256:
soft_max_f32<256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 512:
soft_max_f32<512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 1024:
soft_max_f32<1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 2048:
soft_max_f32<2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
case 4096:
soft_max_f32<4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
default:
soft_max_f32<0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
break;
}
}
static void im2col_f32_f16_cuda(const float* x, half* dst,
@ -7873,7 +8120,21 @@ static void ggml_cuda_op_soft_max(
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
const bool use_f16_soft_max = false;
#else
#ifdef GGML_CUDA_F16
const bool use_f16_soft_max = true;
#else
const bool use_f16_soft_max = false;
#endif // GGML_CUDA_F16
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
if (use_f16_soft_max) {
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
} else {
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
}
(void) dst;
}

View file

@ -436,7 +436,7 @@ struct test_case {
double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) {
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
printf("[%s] NMSE = %.3E > %.3E ", ggml_op_desc(t1), err, ud->max_err);
//for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//}