CUDA: faster large batch FA without tensor cores (#7314)
This commit is contained in:
parent
82ca83db3c
commit
0fc1e820a9
7 changed files with 823 additions and 15 deletions
|
@ -56,7 +56,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int h = blockIdx.y;
|
||||
const uint32_t h = blockIdx.y;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
@ -221,11 +221,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (parallel_blocks != 1 && tid != 0) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
||||
}
|
||||
if (parallel_blocks != 1 && threadIdx.x < ncols) {
|
||||
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue