cuda : unroll some of the loops

This commit is contained in:
Georgi Gerganov 2024-02-03 14:10:01 +02:00
parent 1f8a592482
commit 92472ea22c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16(
half16x16_acc lo[Q16][D16]; half16x16_acc lo[Q16][D16];
// load heads from Q to shared memory // load heads from Q to shared memory
#pragma unroll
for (int j0 = 0; j0 < Q; j0 += num_warps) { for (int j0 = 0; j0 < Q; j0 += num_warps) {
const int j = j0 + warp_id; const int j = j0 + warp_id;
if (j >= Q) { if (j >= Q) {
@@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16(
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
#pragma unroll
for (int i0 = 0; i0 < D2; i0 += NW) { for (int i0 = 0; i0 < D2; i0 += NW) {
const int i = i0 + lane_id; const int i = i0 + lane_id;
if (i >= D2) { if (i >= D2) {