fixup! CUDA: faster softmax via shared memory + fp16 math

This commit is contained in:
JohannesGaessler 2024-01-03 17:29:27 +01:00
parent e1936bb52f
commit 44f30434aa

View file

@ -5233,7 +5233,7 @@ template <bool vals_smem, int ncols_template, int block_size_template, bool need
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data/2, WARP_SIZE);
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
@ -5270,7 +5270,9 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
} else {
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
}
vals[col_smem] = val;
if (!need_check || col_smem < ncols_smem) {
vals[col_smem] = val;
}
max_val = __hmax2(max_val, val);
}
@ -6730,7 +6732,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
if (shmem <= g_device_caps[g_main_device].smpb) {
switch (ncols_x) {
@ -6773,7 +6775,7 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
if (shmem < g_device_caps[g_main_device].smpb) {
switch (ncols_x) {