fixup! CUDA: faster softmax via shared memory + fp16 math
This commit is contained in:
parent
e1936bb52f
commit
44f30434aa
1 changed files with 6 additions and 4 deletions
10
ggml-cuda.cu
10
ggml-cuda.cu
|
@ -5233,7 +5233,7 @@ template <bool vals_smem, int ncols_template, int block_size_template, bool need
|
|||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
const int ncols_smem = GGML_PAD(ncols_data/2, WARP_SIZE);
|
||||
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int rowx = blockIdx.x;
|
||||
|
@ -5270,7 +5270,9 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||
} else {
|
||||
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
||||
}
|
||||
vals[col_smem] = val;
|
||||
if (!need_check || col_smem < ncols_smem) {
|
||||
vals[col_smem] = val;
|
||||
}
|
||||
max_val = __hmax2(max_val, val);
|
||||
}
|
||||
|
||||
|
@ -6730,7 +6732,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
|
|||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(half);
|
||||
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
if (shmem <= g_device_caps[g_main_device].smpb) {
|
||||
switch (ncols_x) {
|
||||
|
@ -6773,7 +6775,7 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
|
|||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
const size_t shmem = (ncols_x + WARP_SIZE)*sizeof(float);
|
||||
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
if (shmem < g_device_caps[g_main_device].smpb) {
|
||||
switch (ncols_x) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue