diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0b342772e..8c47ab059 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -549,11 +549,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int     cc;                 // compute capability
+    size_t  smpb;               // max. shared memory per block
     bool    vmm;                // virtual memory support
     size_t  vmm_granularity;    // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -5228,7 +5229,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-template <int ncols_template, int block_size_template, bool need_check>
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
 static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
     const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
@@ -5243,9 +5244,10 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    extern __shared__ half2 data_soft_max_f16[];
-    half2 * vals = data_soft_max_f16 + 0; // shared memory buffer to cache values between iterations
-    half  * buf_iw = (half *) (data_soft_max_f16 + ncols_smem); // shared memory buffer for inter-warp communication
+    extern __shared__ half data_soft_max_f16[];
+    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data);
 
     half2 max_val = make_half2(-INFINITY, -INFINITY);
 
@@ -5353,7 +5355,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
-template <int ncols_template, int block_size_template>
+template <bool vals_smem, int ncols_template, int block_size_template>
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
@@ -5367,8 +5369,9 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
     const int lane_id = threadIdx.x % WARP_SIZE;
 
     extern __shared__ float data_soft_max_f32[];
-    float * vals = data_soft_max_f32 + 0; // shared memory buffer to cache values between iterations
-    float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols;
 
     float max_val = -INFINITY;
 
@@ -6727,36 +6730,41 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
+    const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-    switch (ncols_x) {
-        case 32:
-            soft_max_f16<32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 64:
-            soft_max_f16<64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 128:
-            soft_max_f16<128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 256:
-            soft_max_f16<256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 512:
-            soft_max_f16<512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 1024:
-            soft_max_f16<1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 2048:
-            soft_max_f16<2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 4096:
-            soft_max_f16<4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        default:
-            soft_max_f16<0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
     }
 }
 
@@ -6765,36 +6773,41 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
-    const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
+    const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-    switch (ncols_x) {
-        case 32:
-            soft_max_f32<32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 64:
-            soft_max_f32<64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 128:
-            soft_max_f32<128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 256:
-            soft_max_f32<256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 512:
-            soft_max_f32<512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 1024:
-            soft_max_f32<1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 2048:
-            soft_max_f32<2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        case 4096:
-            soft_max_f32<4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
-        default:
-            soft_max_f32<0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-            break;
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
     }
 }
 
@@ -7110,6 +7123,7 @@ void ggml_init_cublas() {
 #else
             g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            g_device_caps[id].smpb = prop.sharedMemPerBlock;
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
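
Reviewer note on the pattern above: the patch records `cudaDeviceProp::sharedMemPerBlock` once per device in `g_device_caps[id].smpb` during `ggml_init_cublas`, and each launcher then falls back to staging the per-row scratch values in `dst` (global memory) whenever the requested dynamic shared memory would exceed that limit. Below is a minimal standalone sketch of the same decision; `row_op` and `launch_row_op` are hypothetical stand-ins for `soft_max_f32` / `soft_max_f32_cuda`, not ggml code:

```cuda
// Standalone sketch, assuming hypothetical names (row_op, launch_row_op).
#include <cuda_runtime.h>

#define WARP_SIZE 32

// vals_smem selects where the per-row scratch lives: after the inter-warp
// buffer in dynamic shared memory (fast path), or in the output row itself
// in global memory (fallback for very wide rows).
template <bool vals_smem>
static __global__ void row_op(float * dst, const int ncols) {
    extern __shared__ float buf_iw[]; // inter-warp reduction buffer, always in shared memory
    const int row = blockIdx.x;
    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + row*ncols;

    // A real kernel would cache per-column values in vals between its
    // max/sum/normalize passes; here we only touch the buffer.
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        vals[col] = 0.0f;
    }
}

static void launch_row_op(float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    // Query the per-block limit once, as the patch does in ggml_init_cublas.
    static size_t smpb = 0;
    if (smpb == 0) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, 0);
        smpb = prop.sharedMemPerBlock;
    }

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums(nrows,     1, 1);

    const size_t shmem = (ncols + WARP_SIZE)*sizeof(float);
    if (shmem <= smpb) {
        // Fast path: reduction buffer plus per-row scratch fit in shared memory.
        row_op<true><<<block_nums, block_dims, shmem, stream>>>(dst, ncols);
    } else {
        // Fallback: request only the reduction buffer; scratch is staged in dst.
        const size_t shmem_low = WARP_SIZE*sizeof(float);
        row_op<false><<<block_nums, block_dims, shmem_low, stream>>>(dst, ncols);
    }
}
```

The fallback is safe because each block owns exactly one row of `dst`, so using the output row as scratch cannot race with other blocks. The cost is extra global-memory traffic, which is only paid for rows too wide for the device limit (for `float`, roughly more than 12K columns on a typical 48 KiB-per-block device).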