diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index dda344531..4cf2907e8 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
-        float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float  * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16(
     if (parallel_blocks == 1 || tid != 0) {
         return;
     }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
+    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
@@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
-        float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float  * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16(
             continue;
         }
 
-        half2 dst_meta_val;
+        float2 dst_meta_val;
         if (std::is_same<KQ_acc_t, float>::value) {
-            reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
         } else {
-            dst_meta_val = KQ_max_h2[j0/nwarps];
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
-        reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
+        dst_meta_val.y = KQ_rowsum_j;
         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
     }
 #else
@@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16(
 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_combine_results(
-        const float * __restrict__ VKQ_parts,
-        const half2 * __restrict__ VKQ_meta,
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst) {
 #if FP16_AVAILABLE
     VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
@@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results(
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
-    __shared__ half2 meta[parallel_blocks];
-    if (tid < parallel_blocks) {
-        meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
     }
 
     __syncthreads();
 
-    half kqmax = __low2half(meta[0]);
+    float kqmax = meta[0].x;
 #pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
-        kqmax = __hmax(kqmax, __low2half(meta[l]));
+        kqmax = max(kqmax, meta[l].x);
     }
 
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
 #pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
-        const half diff = __low2half(meta[l]) - kqmax;
-        float KQ_max_scale = hexp(diff);
-        const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint *) &KQ_max_scale) &= ftz_mask;
 
         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
-        VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
+        VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
@@ -643,8 +643,8 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
-    ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -694,8 +694,8 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
-    ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));