fix KQ FP32 precision for parallel_blocks > 1

Johannes Gäßler 2024-04-17 17:31:03 +02:00
parent 2f538b9547
commit 87968de9a9

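Background: with parallel_blocks > 1 the KV cache is split across several thread blocks, each of which writes a partial VKQ accumulator plus a meta pair (kqmax, kqsum) that flash_attn_combine_results uses to merge the partial softmaxes. Keeping that pair in half2 loses precision, and kqsum can exceed the FP16 range once enough terms accumulate, so this commit widens dst_meta to float2. Below is a minimal host-side C++ sketch of the merge formula the combine kernel evaluates; the function name and test values are made up for illustration and are not part of the repo:

#include <cmath>
#include <cstdio>
#include <vector>

// Each parallel block l contributes a partial numerator VKQ_l and the meta
// pair (kqmax_l, kqsum_l). With m = max_l kqmax_l, the merged result is
//   sum_l exp(kqmax_l - m) * VKQ_l  /  sum_l exp(kqmax_l - m) * kqsum_l,
// where every exponent is <= 0, so the rescaling itself cannot overflow.
static float combine(const std::vector<float> & vkq,
                     const std::vector<float> & kqmax,
                     const std::vector<float> & kqsum) {
    float m = kqmax[0];
    for (const float v : kqmax) {
        m = fmaxf(m, v);
    }

    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (size_t l = 0; l < vkq.size(); ++l) {
        const float scale = expf(kqmax[l] - m); // in (0, 1]
        numerator   += scale * vkq[l];
        denominator += scale * kqsum[l];
    }
    return numerator / denominator;
}

int main() {
    // two parallel blocks with different running maxima
    printf("%f\n", combine({1.5f, 0.25f}, {4.0f, 2.0f}, {3.0f, 1.0f}));
}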

@@ -15,8 +15,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16(
     if (parallel_blocks == 1 || tid != 0) {
         return;
     }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
+    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
@@ -194,8 +194,8 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         float * __restrict__ dst,
-        half2 * __restrict__ dst_meta,
+        float2 * __restrict__ dst_meta,
         const float scale,
         const int ne00,
         const int ne01,
@@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16(
             continue;
         }

-        half2 dst_meta_val;
+        float2 dst_meta_val;
         if (std::is_same<KQ_acc_t, float>::value) {
-            reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
         } else {
-            dst_meta_val = KQ_max_h2[j0/nwarps];
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
-        reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
+        dst_meta_val.y = KQ_rowsum_j;
         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
     }
 #else
@@ -572,8 +572,8 @@ static __global__ void flash_attn_ext_f16(
 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_combine_results(
         const float * __restrict__ VKQ_parts,
-        const half2 * __restrict__ VKQ_meta,
+        const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst) {
 #if FP16_AVAILABLE
     VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
@@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results(
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

-    __shared__ half2 meta[parallel_blocks];
-    if (tid < parallel_blocks) {
-        meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
     }

     __syncthreads();

-    half kqmax = __low2half(meta[0]);
+    float kqmax = meta[0].x;
 #pragma unroll
     for (int l = 1; l < parallel_blocks; ++l) {
-        kqmax = __hmax(kqmax, __low2half(meta[l]));
+        kqmax = max(kqmax, meta[l].x);
     }

     float VKQ_numerator = 0.0f;
     float VKQ_denominator = 0.0f;
 #pragma unroll
     for (int l = 0; l < parallel_blocks; ++l) {
-        const half diff = __low2half(meta[l]) - kqmax;
-        float KQ_max_scale = hexp(diff);
-        const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint *) &KQ_max_scale) &= ftz_mask;

         VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
-        VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
+        VKQ_denominator += KQ_max_scale * meta[l].y;
     }

     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
@@ -643,8 +643,8 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
     const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
     ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -694,8 +694,8 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
     const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
     ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
-    ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
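
A note on the ftz_mask lines kept by this change: they are a branchless flush-to-zero that clears the bit pattern of KQ_max_scale whenever diff falls at or below SOFTMAX_FTZ_THRESHOLD, so negligible partial results contribute exactly zero instead of denormal noise. A standalone C++ sketch of the same trick follows; the threshold value used in main is an assumption for illustration, and memcpy stands in for the kernel's pointer cast:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Branchless flush-to-zero: the bool converts to 0 or 1, so the mask is
// either all zeros or all ones; ANDing it into the float's bit pattern
// either clears the value to +0.0f or leaves it untouched.
static float ftz_scale(const float diff, const float threshold) {
    float scale = expf(diff);
    const uint32_t mask = 0xFFFFFFFFu * (diff > threshold);
    uint32_t bits;
    memcpy(&bits, &scale, sizeof(bits));
    bits &= mask;
    memcpy(&scale, &bits, sizeof(bits));
    return scale;
}

int main() {
    // prints exp(-1) for the kept case, then an exact 0 for the flushed case
    printf("%g %g\n", ftz_scale(-1.0f, -20.0f), ftz_scale(-30.0f, -20.0f));
}

Relatedly, the guard in the combine kernel changes from tid < parallel_blocks to tid < 2*parallel_blocks because the float2 meta array is now loaded as 2*parallel_blocks scalar floats, one per thread, rather than one vector element per thread.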