fixup! CUDA: faster softmax via shared memory + fp16 math
This commit is contained in:
parent
64c46fc6f5
commit
ae26053d1f
1 changed files with 80 additions and 66 deletions
146
ggml-cuda.cu
146
ggml-cuda.cu
|
@ -549,11 +549,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
|
|
||||||
struct cuda_device_capabilities {
|
struct cuda_device_capabilities {
|
||||||
int cc; // compute capability
|
int cc; // compute capability
|
||||||
|
size_t smpb; // max. shared memory per block
|
||||||
bool vmm; // virtual memory support
|
bool vmm; // virtual memory support
|
||||||
size_t vmm_granularity; // granularity of virtual memory
|
size_t vmm_granularity; // granularity of virtual memory
|
||||||
};
|
};
|
||||||
|
|
||||||
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
|
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
|
||||||
|
|
||||||
static void * g_scratch_buffer = nullptr;
|
static void * g_scratch_buffer = nullptr;
|
||||||
static size_t g_scratch_size = 0; // disabled by default
|
static size_t g_scratch_size = 0; // disabled by default
|
||||||
|
@ -5228,7 +5229,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
||||||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int ncols_template, int block_size_template, bool need_check>
|
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
||||||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
@ -5243,9 +5244,10 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
extern __shared__ half2 data_soft_max_f16[];
|
extern __shared__ half data_soft_max_f16[];
|
||||||
half2 * vals = data_soft_max_f16 + 0; // shared memory buffer to cache values between iterations
|
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
|
||||||
half * buf_iw = (half *) (data_soft_max_f16 + ncols_smem); // shared memory buffer for inter-warp communication
|
// (shared memory) buffer to cache values between iterations:
|
||||||
|
half2 * vals = vals_smem ? (half2 *) (buf_iw + block_size) : (half2 *) (dst + rowx*ncols_data);
|
||||||
|
|
||||||
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
||||||
|
|
||||||
|
@ -5353,7 +5355,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int ncols_template, int block_size_template>
|
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||||
|
|
||||||
|
@ -5367,8 +5369,9 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
extern __shared__ float data_soft_max_f32[];
|
extern __shared__ float data_soft_max_f32[];
|
||||||
float * vals = data_soft_max_f32 + 0; // shared memory buffer to cache values between iterations
|
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||||
float * buf_iw = data_soft_max_f32 + ncols; // shared memory buffer for inter-warp communication
|
// shared memory buffer to cache values between iterations:
|
||||||
|
float * vals = vals_smem ? buf_iw + block_size : dst + rowx*ncols;
|
||||||
|
|
||||||
float max_val = -INFINITY;
|
float max_val = -INFINITY;
|
||||||
|
|
||||||
|
@ -6727,36 +6730,41 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
|
||||||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
const dim3 block_nums(nrows_x, 1, 1);
|
const dim3 block_nums(nrows_x, 1, 1);
|
||||||
const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
|
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
|
||||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||||
switch (ncols_x) {
|
if (shmem <= g_device_caps[g_main_device].smpb) {
|
||||||
case 32:
|
switch (ncols_x) {
|
||||||
soft_max_f16<32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 32:
|
||||||
break;
|
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 64:
|
break;
|
||||||
soft_max_f16<64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 64:
|
||||||
break;
|
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 128:
|
break;
|
||||||
soft_max_f16<128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 128:
|
||||||
break;
|
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 256:
|
break;
|
||||||
soft_max_f16<256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 256:
|
||||||
break;
|
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 512:
|
break;
|
||||||
soft_max_f16<512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 512:
|
||||||
break;
|
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 1024:
|
break;
|
||||||
soft_max_f16<1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 1024:
|
||||||
break;
|
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 2048:
|
break;
|
||||||
soft_max_f16<2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 2048:
|
||||||
break;
|
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 4096:
|
break;
|
||||||
soft_max_f16<4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 4096:
|
||||||
break;
|
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
default:
|
break;
|
||||||
soft_max_f16<0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
default:
|
||||||
break;
|
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const size_t shmem_low = WARP_SIZE*sizeof(half);
|
||||||
|
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6765,36 +6773,41 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
|
||||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||||
const dim3 block_dims(nth, 1, 1);
|
const dim3 block_dims(nth, 1, 1);
|
||||||
const dim3 block_nums(nrows_x, 1, 1);
|
const dim3 block_nums(nrows_x, 1, 1);
|
||||||
const int64_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
|
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
|
||||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||||
switch (ncols_x) {
|
if (shmem < g_device_caps[g_main_device].smpb) {
|
||||||
case 32:
|
switch (ncols_x) {
|
||||||
soft_max_f32<32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 32:
|
||||||
break;
|
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 64:
|
break;
|
||||||
soft_max_f32<64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 64:
|
||||||
break;
|
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 128:
|
break;
|
||||||
soft_max_f32<128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 128:
|
||||||
break;
|
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 256:
|
break;
|
||||||
soft_max_f32<256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 256:
|
||||||
break;
|
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 512:
|
break;
|
||||||
soft_max_f32<512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 512:
|
||||||
break;
|
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 1024:
|
break;
|
||||||
soft_max_f32<1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 1024:
|
||||||
break;
|
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 2048:
|
break;
|
||||||
soft_max_f32<2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 2048:
|
||||||
break;
|
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
case 4096:
|
break;
|
||||||
soft_max_f32<4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
case 4096:
|
||||||
break;
|
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
default:
|
break;
|
||||||
soft_max_f32<0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
default:
|
||||||
break;
|
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||||
|
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7110,6 +7123,7 @@ void ggml_init_cublas() {
|
||||||
#else
|
#else
|
||||||
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
|
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
g_device_caps[id].smpb = prop.sharedMemPerBlock;
|
||||||
}
|
}
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
g_tensor_split[id] /= total_vram;
|
g_tensor_split[id] /= total_vram;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue