pragma unroll, use_mask template parameter
parent 58c7f6167c
commit 7fca458615
1 changed file with 76 additions and 40 deletions
ggml-cuda.cu (116 changed lines)
@@ -7235,34 +7235,35 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
 #endif
 
 // based on metal version
-template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
+template<int D, int Q, int C, bool use_mask> // D head size, Q queries per block, C cache items per block
 __launch_bounds__(8*WARP_SIZE, 1) // tells the compiler to avoid register spilling even if it reduces occupancy
 static __global__ void flash_attn_ext_f16(
-        const char* __restrict__ q,
-        const char* __restrict__ k,
-        const char* __restrict__ v,
-        const char* __restrict__ mask,
-        float* __restrict__ dst,
-        float scale,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
-        int ne10,
-        int ne11,
-        int ne12,
-        int ne13,
-        int ne31,
-        int nb31,
-        int nb01,
-        int nb02,
-        int nb03,
-        int nb11,
-        int nb12,
-        int nb13,
-        int ne0,
-        int ne1,
-        int ne2,
-        int ne3) {
+        const char * __restrict__ q,
+        const char * __restrict__ k,
+        const char * __restrict__ v,
+        const char * __restrict__ mask,
+        float * __restrict__ dst,
+        const float scale,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
 #if __CUDA_ARCH__ >= CC_VOLTA
     const int warp_id = threadIdx.y;
     const int lane_id = threadIdx.x;
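The point of the new use_mask template parameter: whether the mask is applied becomes a compile-time decision, so the use_mask == false instantiation contains no mask loads or branches at all, instead of checking a runtime pointer. A minimal sketch of the pattern, with made-up names (add_scaled_masked, launch_add) that are not part of ggml:

#include <cuda_runtime.h>

template <int D, bool use_mask>
__global__ void add_scaled_masked(float * dst, const float * src, const float * mask, const float scale, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }

    float v = scale*src[i];
    if (use_mask) {           // compile-time constant: branch and mask load vanish in the false instantiation
        v += mask[i % D];
    }
    dst[i] = v;
}

static void launch_add(float * dst, const float * src, const float * mask, const float scale, const int n, cudaStream_t stream) {
    const dim3 block(256);
    const dim3 grid((n + block.x - 1)/block.x);
    if (mask) {
        add_scaled_masked<128, true ><<<grid, block, 0, stream>>>(dst, src, mask,    scale, n);
    } else {
        add_scaled_masked<128, false><<<grid, block, 0, stream>>>(dst, src, nullptr, scale, n);
    }
}

The host code further down still asserts that a mask is present and only instantiates the true variant for now.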
@@ -7319,24 +7320,28 @@
         }
     }
 
-    nvcuda::wmma::fill_fragment(zr, 0.0);
+    nvcuda::wmma::fill_fragment(zr, 0.0f);
 
     // zero out lo
+#pragma unroll
     for (int j = 0; j < Q16; ++j) {
+#pragma unroll
         for (int i = 0; i < D16; ++i) {
-            nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+            nvcuda::wmma::fill_fragment(lo[j][i], 0.0f);
         }
     }
 
     // zero out shared memory SH
+#pragma unroll
     for (int j = 0; j < Q; ++j) {
+#pragma unroll
         for (int i0 = 0; i0 < SH; i0 += NW) {
             const int i = i0 + lane_id;
             if (i >= SH) {
                 break;
             }
 
-            ss[j*T + i] = 0.0;
+            ss[j*T + i] = 0.0f;
         }
     }
 
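The 0.0 to 0.0f changes replace double literals with float ones in half/float context, and the added #pragma unroll directives target loops whose trip counts (Q16, D16, derived from the template arguments) are known at compile time. A standalone sketch of why that matters, with a hypothetical kernel that is not ggml code:

template <int D>
__global__ void zero_rows(float * dst) {
    constexpr int D16 = D/16;   // trip count fixed at compile time by the template argument

    float acc[D16];             // stays entirely in registers only if the loops are fully unrolled

#pragma unroll
    for (int i = 0; i < D16; ++i) {
        acc[i] = 0.0f;          // float literal: no double constant, no implicit narrowing
    }

    if (threadIdx.x == 0) {
#pragma unroll
        for (int i = 0; i < D16; ++i) {
            dst[blockIdx.x*D16 + i] = acc[i];
        }
    }
}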
@@ -7346,6 +7351,7 @@
     half S = __float2half(0.0f);
     half M[Q];
 
+#pragma unroll
     for (int i = 0; i < Q; ++i) {
         M[i] = CUDART_MIN_DENORM_FP16;
     }
@@ -7375,17 +7381,20 @@
 
     // load the queries from shared memory into local memory
     half16x16_a mq[Q16][D16];
+#pragma unroll
     for (int j = 0; j < Q16; ++j) {
+#pragma unroll
         for (int i = 0; i < D16; ++i) {
             nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
         }
     }
 
     // pointer to the mask
-    const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
+    const half * mp = use_mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
     // prepare diagonal scale matrix
     half16x16_b mscale;
+#pragma unroll
     for (int i = 0; i < 16; ++i) {
         ss[i*T + i] = __float2half(scale);
     }
@@ -7404,27 +7413,31 @@
 #pragma unroll
         for (int cc = 0; cc < C16; ++cc) {
             half16x16_acc mqk[Q16];
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
-                nvcuda::wmma::fill_fragment(mqk[j], 0);
+                nvcuda::wmma::fill_fragment(mqk[j], 0.0f);
             }
 
             const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
+#pragma unroll
             for (int i = 0; i < D16; ++i) {
                 half16x16_bT mk; // transposed key
                 nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
 
+#pragma unroll
                 for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
                 }
             }
 
             // mqk = mqk*scale + mask
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
                 half16x16_a   mqka;
                 half16x16_acc mm;
 
-                if (mp) {
+                if (use_mask) {
                     nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
                 }
 
@@ -7432,7 +7445,7 @@
                 nvcuda::wmma::store_matrix_sync(       ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 nvcuda::wmma::load_matrix_sync (mqka,  ss + 16*j*T + 16*cc, T);
 
-                nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
+                nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, use_mask ? mm : zr);
                 nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
             }
         }
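Every step in this hunk is built from the same warp-level WMMA primitive: fill, load, multiply-accumulate, store on 16x16 tiles. That is also how the diagonal mscale matrix and the mask tile mm get folded into the scores in a single mma_sync. A self-contained sketch of one such tile operation, assuming sm_70 or newer and one warp per block; this is generic illustration, not the ggml kernel:

#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// C (16x16, float) = A (16x16, half) * B (16x16, half) + C, computed by one warp.
__global__ void tile_mma(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

    wmma::fill_fragment(c, 0.0f);                      // zero accumulator, like fill_fragment(zr, 0.0f)
    wmma::load_matrix_sync(a, A, 16);                  // leading dimension = 16 elements
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);                        // c = a*b + c, executed cooperatively by the warp
    wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
}

// usage sketch: tile_mma<<<1, 32>>>(dA, dB, dC);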
@@ -7442,9 +7455,11 @@
         half2 smax = make_half2(-INFINITY, -INFINITY);
 
         // online softmax
+#pragma unroll
         for (int j = 0; j < Q; ++j) {
             const half m = M[j];
 
+#pragma unroll
             for (int p0 = 0; p0 < C2; p0 += NW) {
                 const int p = p0 + lane_id;
 
@@ -7460,6 +7475,7 @@
             half2 ls = make_half2(0.0f, 0.0f);
             half2 M2 = make_half2(M[j], M[j]);
 
+#pragma unroll
             for (int p0 = 0; p0 < C2; p0 += NW) {
                 const int p = p0 + lane_id;
 
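These loops are the online-softmax update, carried out in half2 arithmetic; the kernel seeds M[] with CUDART_MIN_DENORM_FP16 rather than -inf, presumably to keep the half-precision rescale factor exp(M_old - M_new) free of inf/NaN corner cases. A scalar sketch of the recurrence the loops implement (hypothetical helper, not ggml code):

#include <math.h>

struct OnlineSoftmax {
    float M = -INFINITY; // running max of the scores seen so far
    float S = 0.0f;      // running sum of exp(score - M)

    // fold in one new score; returns the factor by which previously
    // accumulated output must be rescaled (the "ms" in the kernel)
    float add(float score) {
        const float M_new = fmaxf(M, score);
        const float ms    = expf(M - M_new);       // 0 on the first call, <= 1 afterwards
        S = S*ms + expf(score - M_new);
        M = M_new;
        return ms;
    }
};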
@@ -7493,12 +7509,14 @@
         }
 
         // O = diag(ms)*O
+#pragma unroll
         for (int j = 0; j < Q16; ++j) {
             half16x16_a mm;
             half16x16_b lob;
 
             nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
+#pragma unroll
             for (int i = 0; i < D16; ++i) {
                 // convert accumulator to matrix_b
                 nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
@@ -7509,26 +7527,32 @@
         }
 
         // restore zeros
+#pragma unroll
         for (int j = 0; j < Q16; ++j) {
             nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
         }
 
         // O = O + (Q*K^T)*V
         {
+#pragma unroll
             for (int cc = 0; cc < C16; ++cc) {
                 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
                 half16x16_b mv[D16];
+#pragma unroll
                 for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
                 }
 
                 half16x16_a ms[Q16];
+#pragma unroll
                 for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
                 }
 
+#pragma unroll
                 for (int j = 0; j < Q16; ++j) {
+#pragma unroll
                     for (int i = 0; i < D16; ++i) {
                         nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
                     }
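In flash-attention terms, the "O = diag(ms)*O" and "O = O + (Q*K^T)*V" hunks together form the per-KV-block update, applied per query row (notation below is mine, not from the source):

\[
s_{c,i} = \mathrm{scale}\cdot q\,k_{c,i}^{\top} + m_{c,i},\qquad
M_{\mathrm{new}} = \max\bigl(M_{\mathrm{old}},\ \max_i s_{c,i}\bigr),\qquad
\tilde P_{c,i} = e^{\,s_{c,i}-M_{\mathrm{new}}}
\]
\[
O \leftarrow e^{\,M_{\mathrm{old}}-M_{\mathrm{new}}}\,O + \tilde P_{c} V_{c},\qquad
S \leftarrow e^{\,M_{\mathrm{old}}-M_{\mathrm{new}}}\,S + \sum_i \tilde P_{c,i}
\]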
@@ -7545,12 +7569,15 @@
     }
 
     // reduce the warps sequentially
+#pragma unroll
     for (int sg = 1; sg < num_warps; ++sg) {
         __syncthreads();
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (warp_id == sg) {
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
+#pragma unroll
                 for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                 }
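The "reduce the warps sequentially" loop merges the per-warp ("simdgroup") partial results through shared memory: warp sg publishes its state, then warp 0 folds it in. A simplified sketch of that cross-warp merge pattern for a plain sum, assuming blockDim.x is a multiple of 32; this is generic illustration, not the ggml code:

__global__ void warp_partial_sum(const float * x, float * out, const int n) {
    __shared__ float tmp;                           // staging slot, reused once per merged warp
    const int warp_id   = threadIdx.x/32;
    const int lane_id   = threadIdx.x%32;
    const int num_warps = blockDim.x/32;

    // each warp reduces its own slice with shuffles
    float s = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        s += x[i];
    }
    for (int off = 16; off > 0; off >>= 1) {
        s += __shfl_xor_sync(0xffffffff, s, off);
    }

    // merge the warps one after another, like the sg loop in the kernel
    for (int sg = 1; sg < num_warps; ++sg) {
        __syncthreads();
        if (warp_id == sg && lane_id == 0) {
            tmp = s;                                // warp sg publishes its partial
        }
        __syncthreads();
        if (warp_id == 0 && lane_id == 0) {
            s += tmp;                               // warp 0 accumulates it
        }
    }
    if (warp_id == 0 && lane_id == 0) {
        *out = s;
    }
}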
@@ -7561,6 +7588,7 @@
 
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
+#pragma unroll
             for (int j = lane_id; j < Q; j += NW) {
                 const half S0 = ss[j*T + 0];
                 const half S1 = ss[j*T + sg*SH + 0];
@@ -7583,6 +7611,7 @@
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+#pragma unroll
             for (int j = 0; j < Q16; ++j) {
                 half16x16_a ms0;
                 half16x16_a ms1;
@@ -7592,6 +7621,7 @@
                 nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
                 nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
 
+#pragma unroll
                 for (int i = 0; i < D16; ++i) {
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, zr);
@@ -7608,7 +7638,9 @@
 
     // store result to shared memory (reuse sq)
     if (warp_id == 0) {
+#pragma unroll
         for (int j = 0; j < Q16; ++j) {
+#pragma unroll
             for (int i = 0; i < D16; ++i) {
                 nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
             }
@@ -7617,9 +7649,11 @@
 
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
+#pragma unroll
         for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
             const half S = ss[j*T + 0];
 
+#pragma unroll
             for (int i0 = 0; i0 < D; i0 += NW) {
                 const int i = i0 + lane_id;
                 if (i >= D) {
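In the notation of the block-update formulas above, this final rescale is the softmax normalization, per query row j (notation mine):

\[
O_j \leftarrow \frac{1}{S_j}\,O_j
\]

where S_j is the accumulated denominator stored at ss[j*T + 0] after the warps have been merged.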
@@ -11927,9 +11961,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     //const size_t shmem_max = 96*1024;
     //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
 
+    GGML_ASSERT(mask); // FIXME case without mask
     switch (Q->ne[0]) {
         case 64:
-            flash_attn_ext_f16<64, NQPB, NCPW>
+            flash_attn_ext_f16<64, NQPB, NCPW, true>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11946,7 +11981,7 @@
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, NQPB, NCPW>
+            flash_attn_ext_f16<80, NQPB, NCPW, true>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11963,7 +11998,7 @@
                 );
             break;
         case 96:
-            flash_attn_ext_f16<96, NQPB, NCPW>
+            flash_attn_ext_f16<96, NQPB, NCPW, true>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11980,7 +12015,7 @@
                 );
             break;
         case 112:
-            flash_attn_ext_f16<112, NQPB, NCPW>
+            flash_attn_ext_f16<112, NQPB, NCPW, true>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -11997,7 +12032,7 @@
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, NQPB, NCPW>
+            flash_attn_ext_f16<128, NQPB, NCPW, true>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -12014,7 +12049,7 @@
                 );
             break;
         case 256:
-            flash_attn_ext_f16<256, NQPB, NCPW>
+            flash_attn_ext_f16<256, NQPB, NCPW, true>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -12031,6 +12066,7 @@
                 );
             break;
         default:
             GGML_ASSERT(false);
             break;
     }
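Because the head size D and use_mask are template arguments, they must be compile-time constants inside the kernel, so the host side maps the runtime Q->ne[0] onto a fixed set of instantiations with a switch and (given the GGML_ASSERT(mask) above) always picks the use_mask = true variant for now. A generic sketch of that dispatch pattern with a hypothetical kernel (scale_rows), not ggml code:

#include <cuda_runtime.h>

template <int D>
__global__ void scale_rows(float * dst, const float * src, const float scale) {
    const int row = blockIdx.x;
    for (int i = threadIdx.x; i < D; i += blockDim.x) {
        dst[row*D + i] = scale*src[row*D + i];   // D is a compile-time constant here
    }
}

static bool launch_scale_rows(float * dst, const float * src, const float scale,
                              const int d, const int nrows, cudaStream_t stream) {
    const dim3 grid(nrows);
    const dim3 block(32);
    switch (d) {                                 // runtime value -> template instantiation
        case  64: scale_rows< 64><<<grid, block, 0, stream>>>(dst, src, scale); return true;
        case 128: scale_rows<128><<<grid, block, 0, stream>>>(dst, src, scale); return true;
        case 256: scale_rows<256><<<grid, block, 0, stream>>>(dst, src, scale); return true;
        default:  return false;                  // unsupported head size
    }
}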