ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

This commit is contained in:
Georgi Gerganov 2024-01-31 19:17:16 +02:00
parent 2ddc9bbef1
commit 8ad92dc1ec
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 79 additions and 62 deletions

View file

@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
}
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
if (need_check && col_data + 0 >= ncols_data) {
val.x = -INFINITY;
} else {
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f);
}
if (need_check && col_data + WARP_SIZE >= ncols_data) {
val.y = -INFINITY;
} else {
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f);
}
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
vals[col_smem] = val;
@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
}
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
}
}
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max(
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
if (use_f16_soft_max) {
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
} else {
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
}
(void) dst;