cuda : unroll some of the loops
This commit is contained in:
parent
1f8a592482
commit
92472ea22c
1 changed file with 2 additions and 0 deletions
|
@@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
half16x16_acc lo[Q16][D16];
|
half16x16_acc lo[Q16][D16];
|
||||||
|
|
||||||
// load heads from Q to shared memory
|
// load heads from Q to shared memory
|
||||||
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < Q; j0 += num_warps) {
|
for (int j0 = 0; j0 < Q; j0 += num_warps) {
|
||||||
const int j = j0 + warp_id;
|
const int j = j0 + warp_id;
|
||||||
if (j >= Q) {
|
if (j >= Q) {
|
||||||
|
@@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D2; i0 += NW) {
|
for (int i0 = 0; i0 < D2; i0 += NW) {
|
||||||
const int i = i0 + lane_id;
|
const int i = i0 + lane_id;
|
||||||
if (i >= D2) {
|
if (i >= D2) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue