fix batch size 2-8
This commit is contained in:
parent
617f129e43
commit
d9bcb92f75
1 changed file with 7 additions and 6 deletions
|
@@ -196,15 +196,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||||
kqsum[j] = kqsum_shared[j][threadIdx.x];
|
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
||||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
||||||
|
|
||||||
half dst_val = (__low2half(VKQ[j]) + __high2half(VKQ[j]));
|
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
||||||
if (parallel_blocks == 1) {
|
if (parallel_blocks == 1) {
|
||||||
dst_val /= kqsum[j];
|
dst_val /= kqsum[j_VKQ];
|
||||||
}
|
}
|
||||||
dst[D*gridDim.y*(blockIdx.x*ncols + j) + D*blockIdx.y + tid] = dst_val;
|
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks == 1 || tid != 0) {
|
if (parallel_blocks == 1 || tid != 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue