fixup! CUDA: faster softmax via shared memory + fp16 math
This commit is contained in:
parent
64c46fc6f5
commit
ae26053d1f
1 changed files with 80 additions and 66 deletions
146
ggml-cuda.cu
146
ggml-cuda.cu
|
@ -549,11 +549,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
|||
|
||||
struct cuda_device_capabilities {
|
||||
int cc; // compute capability
|
||||
size_t smpb; // max. shared memory per block
|
||||
bool vmm; // virtual memory support
|
||||
size_t vmm_granularity; // granularity of virtual memory
|
||||
};
|
||||
|
||||
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
|
||||
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
|
||||
|
||||
static void * g_scratch_buffer = nullptr;
|
||||
static size_t g_scratch_size = 0; // disabled by default
|
||||
|
@ -5228,7 +5229,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||
}
|
||||
|
||||
template <int ncols_template, int block_size_template, bool need_check>
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
||||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
@ -5243,9 +5244,10 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
extern __shared__ half2 data_soft_max_f16[];
|
||||
half2 * vals = data_soft_max_f16 + 0; // shared memory buffer to cache values between iterations
|
||||
half * buf_iw = (half *) (data_soft_max_f16 + ncols_smem); // shared memory buffer for inter-warp communication
|
||||
extern __shared__ half data_soft_max_f16[];
|
||||
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
|
||||
// (shared memory) buffer to cache values between iterations:
|
||||
half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data);
|
||||
|
||||
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
||||
|
||||
|
@ -5353,7 +5355,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
template <int ncols_template, int block_size_template>
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
|
@ -5367,8 +5369,9 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
|||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
extern __shared__ float data_soft_max_f32[];
|
||||
float * vals = data_soft_max_f32 + 0; // shared memory buffer to cache values between iterations
|
||||
float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication
|
||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||
// shared memory buffer to cache values between iterations:
|
||||
float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols;
|
||||
|
||||
float max_val = -INFINITY;
|
||||
|
||||
|
@ -6727,36 +6730,41 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
|
|||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
|
||||
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f16<32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f16<64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f16<128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f16<256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f16<512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f16<1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f16<2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f16<4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
default:
|
||||
soft_max_f16<0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
if (shmem <= g_device_caps[g_main_device].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
default:
|
||||
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const size_t shmem_low = WARP_SIZE*sizeof(half);
|
||||
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6765,36 +6773,41 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
|
|||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
|
||||
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32<32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32<64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32<128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32<256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32<512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32<1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32<2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32<4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32<0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
if (shmem < g_device_caps[g_main_device].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7110,6 +7123,7 @@ void ggml_init_cublas() {
|
|||
#else
|
||||
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
g_device_caps[id].smpb = prop.sharedMemPerBlock;
|
||||
}
|
||||
for (int id = 0; id < g_device_count; ++id) {
|
||||
g_tensor_split[id] /= total_vram;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue