pragma unroll, use_mask template parameter

commit 7fca458615
parent 58c7f6167c
Author: Johannes Gäßler
Date:   2024-03-19 12:00:51 +01:00

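The heart of the change is visible in the first hunk: the runtime `mask` pointer check inside the kernel becomes a compile-time `bool use_mask` template parameter, so each instantiation carries only the code path it needs. A minimal sketch of the pattern (the kernel name and body here are hypothetical; only `use_mask` comes from the diff):

    #include <cuda_fp16.h>

    // `use_mask` is a compile-time constant, so the compiler folds the branch
    // and drops the mask load entirely from the `false` instantiation.
    template <bool use_mask>
    static __global__ void kernel_sketch(const half * mask, half * dst) {
        half m = __float2half(0.0f);
        if (use_mask) {
            m = mask[blockIdx.x]; // emitted only when use_mask == true
        }
        dst[blockIdx.x] = m;
    }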

@@ -7235,34 +7235,35 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
#endif
// based on metal version
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
template<int D, int Q, int C, bool use_mask> // D head size, Q queries per block, C cache items per block
__launch_bounds__(8*WARP_SIZE, 1) // tells the compiler to avoid register spilling even if it reduces occupancy
static __global__ void flash_attn_ext_f16(
const char* __restrict__ q,
const char* __restrict__ k,
const char* __restrict__ v,
const char* __restrict__ mask,
float* __restrict__ dst,
float scale,
int ne00,
int ne01,
int ne02,
int ne03,
int ne10,
int ne11,
int ne12,
int ne13,
int ne31,
int nb31,
int nb01,
int nb02,
int nb03,
int nb11,
int nb12,
int nb13,
int ne0,
int ne1,
int ne2,
int ne3) {
const char * __restrict__ q,
const char * __restrict__ k,
const char * __restrict__ v,
const char * __restrict__ mask,
float * __restrict__ dst,
const float scale,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#if __CUDA_ARCH__ >= CC_VOLTA
const int warp_id = threadIdx.y;
const int lane_id = threadIdx.x;
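`__launch_bounds__(8*WARP_SIZE, 1)` promises ptxas that the kernel never launches with more than 8*32 = 256 threads per block and requests at least one resident block per SM; with that bound the compiler can budget more registers per thread before spilling to local memory, which is what the comment above means by trading occupancy for no spills. A minimal sketch (the kernel itself is hypothetical):

    #define WARP_SIZE 32

    // maxThreadsPerBlock = 8*WARP_SIZE, minBlocksPerMultiprocessor = 1.
    // A launch exceeding the bound (e.g. <<<1, 512>>>) fails at launch
    // time instead of silently spilling registers.
    __launch_bounds__(8*WARP_SIZE, 1)
    static __global__ void bounded_kernel(float * out) {
        out[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
    }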
@@ -7319,24 +7320,28 @@ static __global__ void flash_attn_ext_f16(
}
}
nvcuda::wmma::fill_fragment(zr, 0.0);
nvcuda::wmma::fill_fragment(zr, 0.0f);
// zero out lo
#pragma unroll
for (int j = 0; j < Q16; ++j) {
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
nvcuda::wmma::fill_fragment(lo[j][i], 0.0f);
}
}
// zero out shared memory SH
#pragma unroll
for (int j = 0; j < Q; ++j) {
#pragma unroll
for (int i0 = 0; i0 < SH; i0 += NW) {
const int i = i0 + lane_id;
if (i >= SH) {
break;
}
ss[j*T + i] = 0.0;
ss[j*T + i] = 0.0f;
}
}
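The `#pragma unroll` additions all sit on loops whose trip counts (`Q16`, `D16`, `Q`, `SH`) derive from the `D`/`Q`/`C` template parameters, so they are compile-time constants; full unrolling then lets arrays indexed by the loop variable, like `lo[j][i]`, resolve to registers instead of local memory. A hedged sketch of the idiom with a hypothetical helper:

    // D16 is a template parameter, so the trip count is known at compile
    // time, the loop unrolls completely, and acc[i] stays in registers.
    template <int D16>
    static __device__ void zero_acc(float (&acc)[D16]) {
    #pragma unroll
        for (int i = 0; i < D16; ++i) {
            acc[i] = 0.0f;
        }
    }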
@@ -7346,6 +7351,7 @@ static __global__ void flash_attn_ext_f16(
half S = __float2half(0.0f);
half M[Q];
#pragma unroll
for (int i = 0; i < Q; ++i) {
M[i] = CUDART_MIN_DENORM_FP16;
}
@@ -7375,17 +7381,20 @@ static __global__ void flash_attn_ext_f16(
// load the queries from shared memory into local memory
half16x16_a mq[Q16][D16];
#pragma unroll
for (int j = 0; j < Q16; ++j) {
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
}
}
// pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
const half * mp = use_mask ? (const half *) (mask + iq1*nb31) : nullptr;
// prepare diagonal scale matrix
half16x16_b mscale;
#pragma unroll
for (int i = 0; i < 16; ++i) {
ss[i*T + i] = __float2half(scale);
}
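Filling the diagonal with `scale` lets the kernel fold the softmax scaling and the mask add into the single `mma_sync` of the next hunk; in matrix terms this is just a restatement of the `mqk = mqk*scale + mask` comment below:

    mqk <- mqk * diag(scale) + mask,   i.e.   (Q K^T)_ij * scale + mask_ij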
@@ -7404,27 +7413,31 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int cc = 0; cc < C16; ++cc) {
half16x16_acc mqk[Q16];
#pragma unroll
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0);
nvcuda::wmma::fill_fragment(mqk[j], 0.0f);
}
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
#pragma unroll
for (int i = 0; i < D16; ++i) {
half16x16_bT mk; // transposed key
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
#pragma unroll
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
}
}
// mqk = mqk*scale + mask
#pragma unroll
for (int j = 0; j < Q16; ++j) {
half16x16_a mqka;
half16x16_acc mm;
if (mp) {
if (use_mask) {
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
}
@@ -7432,7 +7445,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, use_mask ? mm : zr);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
}
}
@@ -7442,9 +7455,11 @@ static __global__ void flash_attn_ext_f16(
half2 smax = make_half2(-INFINITY, -INFINITY);
// online softmax
#pragma unroll
for (int j = 0; j < Q; ++j) {
const half m = M[j];
#pragma unroll
for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id;
@@ -7460,6 +7475,7 @@ static __global__ void flash_attn_ext_f16(
half2 ls = make_half2(0.0f, 0.0f);
half2 M2 = make_half2(M[j], M[j]);
#pragma unroll
for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id;
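For orientation: the block labelled "online softmax" maintains a running row maximum `M[j]` and normalizer `S` so scores can be consumed one KV chunk at a time. The standard streaming recurrence it implements (here in half2 arithmetic) is

    M' = max(M, max_p s_p)            // new running maximum
    ms = exp(M - M')                  // rescale factor for old accumulators
    S' = S*ms + sum_p exp(s_p - M')   // updated normalizer
    O' = diag(ms)*O                   // the "O = diag(ms)*O" step below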
@@ -7493,12 +7509,14 @@ static __global__ void flash_attn_ext_f16(
}
// O = diag(ms)*O
#pragma unroll
for (int j = 0; j < Q16; ++j) {
half16x16_a mm;
half16x16_b lob;
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
#pragma unroll
for (int i = 0; i < D16; ++i) {
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
@@ -7509,26 +7527,32 @@ static __global__ void flash_attn_ext_f16(
}
// restore zeros
#pragma unroll
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
}
// O = O + (Q*K^T)*V
{
#pragma unroll
for (int cc = 0; cc < C16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mv[D16];
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
}
half16x16_a ms[Q16];
#pragma unroll
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
}
#pragma unroll
for (int j = 0; j < Q16; ++j) {
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
}
@@ -7545,12 +7569,15 @@ static __global__ void flash_attn_ext_f16(
}
// reduce the warps sequentially
#pragma unroll
for (int sg = 1; sg < num_warps; ++sg) {
__syncthreads();
// each simdgroup stores its output to shared memory, reusing sq
if (warp_id == sg) {
#pragma unroll
for (int j = 0; j < Q16; ++j) {
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
@@ -7561,6 +7588,7 @@ static __global__ void flash_attn_ext_f16(
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
#pragma unroll
for (int j = lane_id; j < Q; j += NW) {
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];
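The sequential reduction merges two partial softmax states (M0, S0, O_0) and (M1, S1, O_1) covering the same rows. Written out, the merge the code performs is

    M   = max(M0, M1)
    ms0 = exp(M0 - M),  ms1 = exp(M1 - M)
    S   = S0*ms0 + S1*ms1
    O_0 = diag(ms0)*O_0 + diag(ms1)*O_1   // the comment in the next hunk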
@@ -7583,6 +7611,7 @@ static __global__ void flash_attn_ext_f16(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
#pragma unroll
for (int j = 0; j < Q16; ++j) {
half16x16_a ms0;
half16x16_a ms1;
@@ -7592,6 +7621,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
@@ -7608,7 +7638,9 @@ static __global__ void flash_attn_ext_f16(
// store result to shared memory (reuse sq)
if (warp_id == 0) {
#pragma unroll
for (int j = 0; j < Q16; ++j) {
#pragma unroll
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
@@ -7617,9 +7649,11 @@ static __global__ void flash_attn_ext_f16(
// final rescale with 1/S and store to global memory
if (warp_id == 0) {
#pragma unroll
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
#pragma unroll
for (int i0 = 0; i0 < D; i0 += NW) {
const int i = i0 + lane_id;
if (i >= D) {
@@ -11927,9 +11961,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
//const size_t shmem_max = 96*1024;
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
GGML_ASSERT(mask); // FIXME case without mask
switch (Q->ne[0]) {
case 64:
flash_attn_ext_f16<64, NQPB, NCPW>
flash_attn_ext_f16<64, NQPB, NCPW, true>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -11946,7 +11981,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 80:
flash_attn_ext_f16<80, NQPB, NCPW>
flash_attn_ext_f16<80, NQPB, NCPW, true>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -11963,7 +11998,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 96:
flash_attn_ext_f16<96, NQPB, NCPW>
flash_attn_ext_f16<96, NQPB, NCPW, true>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -11980,7 +12015,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 112:
flash_attn_ext_f16<112, NQPB, NCPW>
flash_attn_ext_f16<112, NQPB, NCPW, true>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -11997,7 +12032,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 128:
flash_attn_ext_f16<128, NQPB, NCPW>
flash_attn_ext_f16<128, NQPB, NCPW, true>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -12014,7 +12049,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 256:
flash_attn_ext_f16<256, NQPB, NCPW>
flash_attn_ext_f16<256, NQPB, NCPW, true>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -12031,6 +12066,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
default:
GGML_ASSERT(false);
break;
}
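Every instantiation in this switch passes `use_mask = true`; the no-mask case is ruled out by the `GGML_ASSERT(mask)` above (hence its FIXME). A hypothetical sketch of the dispatch once a `false` path exists, shown for the 128 case and not part of this commit, would branch on the runtime pointer:

    if (mask) {
        flash_attn_ext_f16<128, NQPB, NCPW, true>
            <<<blocks_num, block_dim, shmem, main_stream>>>(/* same args */);
    } else {
        flash_attn_ext_f16<128, NQPB, NCPW, false>
            <<<blocks_num, block_dim, shmem, main_stream>>>(/* same args */);
    }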