fixup! CUDA: faster softmax via shared memory + fp16 math

Author: JohannesGaessler
Date:   2024-01-03 17:29:27 +01:00
Parent: e1936bb52f
Commit: 44f30434aa


@@ -5233,7 +5233,7 @@ template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
 static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
     const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
-    const int ncols_smem = GGML_PAD(ncols_data/2, WARP_SIZE);
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
 
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
@@ -5270,7 +5270,9 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
         } else {
             val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
         }
-        vals[col_smem] = val;
+        if (!need_check || col_smem < ncols_smem) {
+            vals[col_smem] = val;
+        }
 
         max_val = __hmax2(max_val, val);
     }
 
@@ -6730,7 +6732,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
     while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
     if (shmem <= g_device_caps[g_main_device].smpb) {
         switch (ncols_x) {
@@ -6773,7 +6775,7 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
     if (shmem < g_device_caps[g_main_device].smpb) {
         switch (ncols_x) {
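
For context (not part of the commit): the f16 kernel handles columns as half2 pairs, so the number of shared-memory slots it computes, ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2, corresponds to the column count rounded up to a multiple of 2*WARP_SIZE, and the launch code now sizes the shared-memory buffer from that padded count instead of the raw ncols_x. The snippet below is a minimal host-side sketch of the sizing arithmetic only; round_up is a hypothetical stand-in for GGML_PAD and assumes it rounds x up to the next multiple of n, and WARP_SIZE is taken to be 32.

// Not from the repository: a hypothetical host-side sketch of the shared-memory
// sizing arithmetic. round_up stands in for GGML_PAD and is assumed to round x
// up to the next multiple of n; WARP_SIZE is assumed to be 32 as on NVIDIA GPUs.
#include <cstdio>
#include <cstddef>

static size_t round_up(size_t x, size_t n) { return (x + n - 1) / n * n; }

int main() {
    const size_t WARP_SIZE   = 32;
    const size_t sizeof_half = 2; // sizeof(half) in CUDA

    for (size_t ncols_x : {80u, 100u, 1000u}) {
        // old f16 sizing: based on the raw column count
        const size_t shmem_old = (ncols_x + WARP_SIZE)*sizeof_half;
        // new f16 sizing: pad the column count to a multiple of 2*WARP_SIZE first,
        // mirroring the ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2 half2 slots in the kernel
        const size_t shmem_new = (round_up(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof_half;
        printf("ncols_x=%4zu  old=%4zu bytes  new=%4zu bytes\n", ncols_x, shmem_old, shmem_new);
    }
    return 0;
}

For ncols_x = 100, for example, the old formula reserves (100 + 32)*sizeof(half) = 264 bytes, while the padded formula reserves (GGML_PAD(100, 64) + 32)*sizeof(half) = 320 bytes, covering the full padded half2 range addressed via ncols_smem.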