diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a3a6c6455..deda4cc70 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,6 +6462,7 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory +#pragma unroll for (int j0 = 0; j0 < Q; j0 += num_warps) { const int j = j0 + warp_id; if (j >= Q) { @@ -6470,6 +6471,7 @@ static __global__ void flash_attn_ext_f16( const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); +#pragma unroll for (int i0 = 0; i0 < D2; i0 += NW) { const int i = i0 + lane_id; if (i >= D2) {