diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 36479b217..f6289822e 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -3,8 +3,9 @@ #include -#define FATTN_KQ_STRIDE 256 -#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. template // D == head size __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) @@ -338,10 +339,16 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); + half2 val = KQ2[j*(kqs_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, val); + KQ2[j*(kqs_padded/2) + k] = val; } KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + const half2 diff = KQ_max[j0/nwarps] - KQ_max_new; + KQ_max_scale[j0/nwarps] = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask; KQ_max[j0/nwarps] = KQ_max_new; half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); @@ -350,8 +357,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; half2 val = KQ2[j*(kqs_padded/2) + k]; - val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - val = h2exp(val - KQ_max[j0/nwarps]); + const half2 diff = val - KQ_max[j0/nwarps]; + val = h2exp(diff); + const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &val) &= ftz_mask; KQ_rowsum_add += val; KQ2[j*(kqs_padded/2) + k] = val; } @@ -501,7 +510,10 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; #pragma unroll for (int l = 0; l < parallel_blocks; ++l) { - float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); + const half diff = __low2half(meta[l]) - kqmax; + float KQ_max_scale = hexp(diff); + const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD)); + *((uint *) &KQ_max_scale) &= ftz_mask; VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; VKQ_denominator += KQ_max_scale * __high2float(meta[l]);