fix KQ FP32 precision for parallel_blocks > 1
This commit is contained in:
parent
2f538b9547
commit
87968de9a9
1 changed file with 24 additions and 24 deletions
|
@ -16,7 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
half2 * __restrict__ dst_meta,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
|
@ -180,7 +180,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
if (parallel_blocks == 1 || tid != 0) {
|
||||
return;
|
||||
}
|
||||
dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_half2(kqmax, kqsum);
|
||||
dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
|
@ -195,7 +195,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
half2 * __restrict__ dst_meta,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
|
@ -555,13 +555,13 @@ static __global__ void flash_attn_ext_f16(
|
|||
continue;
|
||||
}
|
||||
|
||||
half2 dst_meta_val;
|
||||
float2 dst_meta_val;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
|
||||
dst_meta_val.x = KQ_max_f[j0/nwarps];
|
||||
} else {
|
||||
dst_meta_val = KQ_max_h2[j0/nwarps];
|
||||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
|
@ -573,7 +573,7 @@ template<int D, int parallel_blocks> // D == head size
|
|||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const half2 * __restrict__ VKQ_meta,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst) {
|
||||
#if FP16_AVAILABLE
|
||||
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
||||
|
@ -583,30 +583,30 @@ static __global__ void flash_attn_combine_results(
|
|||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
__shared__ half2 meta[parallel_blocks];
|
||||
if (tid < parallel_blocks) {
|
||||
meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid];
|
||||
__shared__ float2 meta[parallel_blocks];
|
||||
if (tid < 2*parallel_blocks) {
|
||||
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
half kqmax = __low2half(meta[0]);
|
||||
float kqmax = meta[0].x;
|
||||
#pragma unroll
|
||||
for (int l = 1; l < parallel_blocks; ++l) {
|
||||
kqmax = __hmax(kqmax, __low2half(meta[l]));
|
||||
kqmax = max(kqmax, meta[l].x);
|
||||
}
|
||||
|
||||
float VKQ_numerator = 0.0f;
|
||||
float VKQ_denominator = 0.0f;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const half diff = __low2half(meta[l]) - kqmax;
|
||||
float KQ_max_scale = hexp(diff);
|
||||
const uint ftz_mask = 0xFFFFFFFF * (diff > __float2half(SOFTMAX_FTZ_THRESHOLD));
|
||||
const float diff = meta[l].x - kqmax;
|
||||
const float KQ_max_scale = expf(diff);
|
||||
const uint ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * __high2float(meta[l]);
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
|
@ -644,7 +644,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
|
|||
ggml_cuda_pool & pool, cudaStream_t main_stream
|
||||
) {
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
|
@ -695,7 +695,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
|
|||
ggml_cuda_pool & pool, cudaStream_t main_stream
|
||||
) {
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<half2> dst_tmp_meta(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue