ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
This commit is contained in:
parent
2ddc9bbef1
commit
8ad92dc1ec
7 changed files with 79 additions and 62 deletions
20
ggml-cuda.cu
20
ggml-cuda.cu
|
@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
||||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
||||
|
@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||
if (need_check && col_data + 0 >= ncols_data) {
|
||||
val.x = -INFINITY;
|
||||
} else {
|
||||
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
|
||||
val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f);
|
||||
}
|
||||
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
||||
val.y = -INFINITY;
|
||||
} else {
|
||||
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
||||
val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f);
|
||||
}
|
||||
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
||||
vals[col_smem] = val;
|
||||
|
@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
|
|||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
|||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
|
||||
const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f);
|
||||
vals[col] = val;
|
||||
max_val = max(max_val, val);
|
||||
}
|
||||
|
@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
||||
}
|
||||
|
||||
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
|
@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
|
|||
}
|
||||
}
|
||||
|
||||
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
|
@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max(
|
|||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
|
@ -9080,9 +9080,9 @@ static void ggml_cuda_op_soft_max(
|
|||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
|
||||
|
||||
if (use_f16_soft_max) {
|
||||
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
} else {
|
||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
}
|
||||
|
||||
(void) dst;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue