cuda : express strides with helper constants

This commit is contained in:
Georgi Gerganov 2024-02-04 11:08:47 +02:00
parent 1846e92a90
commit a647257b47
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -6399,10 +6399,16 @@ static __global__ void flash_attn_f32(
} }
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b; // queries, dims, cache per fragment
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT; #define QPF 16
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc; #define DPF 16
#define CPF 16
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, QPF, CPF, DPF, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, QPF, CPF, DPF, half, nvcuda::wmma::row_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, QPF, CPF, DPF, half, nvcuda::wmma::col_major> half16x16_bT;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, QPF, CPF, DPF, half> half16x16_acc;
#endif #endif
// based on metal version // based on metal version
@ -6443,9 +6449,9 @@ static __global__ void flash_attn_ext_f16(
const int iq2 = blockIdx.y; const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q; const int iq1 = blockIdx.x * Q;
const int D16 = D/16; const int DPT = D/DPF; // dims per thread
const int Q16 = Q/16; const int QPT = Q/QPF; // queries per thread
const int C16 = C/16; const int CPT = C/CPF; // cache per thread
const int NW = WARP_SIZE; const int NW = WARP_SIZE;
const int SH = (C + Q); // shared memory per simdgroup in (half) const int SH = (C + Q); // shared memory per simdgroup in (half)
@ -6463,7 +6469,7 @@ static __global__ void flash_attn_ext_f16(
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
half16x16_acc zr; half16x16_acc zr;
half16x16_acc lo[Q16][D16]; half16x16_acc lo[QPT][DPT];
// load heads from Q to shared memory // load heads from Q to shared memory
#pragma unroll #pragma unroll
@ -6493,8 +6499,8 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(zr, 0.0); nvcuda::wmma::fill_fragment(zr, 0.0);
// zero out lo // zero out lo
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::fill_fragment(lo[j][i], 0.0); nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
} }
} }
@ -6545,10 +6551,10 @@ static __global__ void flash_attn_ext_f16(
const int iv3 = iq3 / rv3; const int iv3 = iq3 / rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
half16x16_a mq[Q16][D16]; half16x16_a mq[QPT][DPT];
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); nvcuda::wmma::load_matrix_sync(mq[j][i], sq + QPF*j*T + DPF*i, T);
} }
} }
@ -6557,7 +6563,7 @@ static __global__ void flash_attn_ext_f16(
// prepare diagonal scale matrix // prepare diagonal scale matrix
half16x16_b mscale; half16x16_b mscale;
for (int i = 0; i < 16; ++i) { for (int i = 0; i < QPF; ++i) {
ss[i*T + i] = __float2half(scale); ss[i*T + i] = __float2half(scale);
} }
nvcuda::wmma::load_matrix_sync(mscale, ss, T); nvcuda::wmma::load_matrix_sync(mscale, ss, T);
@ -6573,38 +6579,38 @@ static __global__ void flash_attn_ext_f16(
// Q*K^T // Q*K^T
{ {
#pragma unroll #pragma unroll
for (int cc = 0; cc < C16; ++cc) { for (int cc = 0; cc < CPT; ++cc) {
half16x16_acc mqk[Q16]; half16x16_acc mqk[QPT];
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0); nvcuda::wmma::fill_fragment(mqk[j], 0);
} }
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); const half * pk = (const half *) ((const char *) k + ((ic + CPF*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
half16x16_bT mk; // transposed key half16x16_bT mk; // transposed key
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); nvcuda::wmma::load_matrix_sync(mk, pk + DPF*i, nb11/sizeof(half));
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
} }
} }
// mqk = mqk*scale + mask // mqk = mqk*scale + mask
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
half16x16_a mqka; half16x16_a mqka;
half16x16_acc mm; half16x16_acc mm;
if (mp) { if (mp) {
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync(mm, mp + QPF*j*(nb31/sizeof(half)) + ic + CPF*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
} }
// convert accumulator to matrix_a // convert accumulator to matrix_a
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync( ss + QPF*j*T + CPF*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); nvcuda::wmma::load_matrix_sync (mqka, ss + QPF*j*T + CPF*cc, T);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync(ss + QPF*j*T + CPF*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
} }
} }
} }
@ -6664,43 +6670,43 @@ static __global__ void flash_attn_ext_f16(
} }
// O = diag(ms)*O // O = diag(ms)*O
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
half16x16_a mm; half16x16_a mm;
half16x16_b lob; half16x16_b lob;
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(mm, ss + QPF*j*T + C + QPF*j, T);
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
// convert accumulator to matrix_b // convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync( ss + QPF*j*T + C + QPF*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync (lob, ss + QPF*j*T + C + QPF*j, T);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
} }
} }
// restore zeros // restore zeros
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync(ss + QPF*j*T + C + QPF*j, zr, T, nvcuda::wmma::mem_row_major);
} }
// O = O + (Q*K^T)*V // O = O + (Q*K^T)*V
{ {
for (int cc = 0; cc < C16; ++cc) { for (int cc = 0; cc < CPT; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); const half * pv = (const half *) ((const char *) v + ((ic + CPF*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mv[D16]; half16x16_b mv[DPT];
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); nvcuda::wmma::load_matrix_sync(mv[i], pv + DPF*i, nb21/sizeof(half));
} }
half16x16_a ms[Q16]; half16x16_a ms[QPT];
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); nvcuda::wmma::load_matrix_sync(ms[j], ss + QPF*j*T + CPF*cc, T);
} }
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
} }
} }
@ -6721,9 +6727,9 @@ static __global__ void flash_attn_ext_f16(
// each simdgroup stores its output to shared memory, reusing sq // each simdgroup stores its output to shared memory, reusing sq
if (warp_id == sg) { if (warp_id == sg) {
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync(sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major);
} }
} }
} }
@ -6754,22 +6760,22 @@ static __global__ void flash_attn_ext_f16(
} }
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
half16x16_a ms0; half16x16_a ms0;
half16x16_a ms1; half16x16_a ms1;
half16x16_b t; half16x16_b t;
half16x16_acc t2; half16x16_acc t2;
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms0, ss + QPF*j*T + C + QPF*j, T);
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); nvcuda::wmma::load_matrix_sync(ms1, ss + QPF*j*T + C + QPF*j + sg*SH, T);
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::load_matrix_sync(t, sq + QPF*j*T + DPF*i, T);
nvcuda::wmma::mma_sync(t2, ms1, t, zr); nvcuda::wmma::mma_sync(t2, ms1, t, zr);
// convert accumulator to matrix_b // convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync( sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); nvcuda::wmma::load_matrix_sync (t, sq + QPF*j*T + DPF*i, T);
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
} }
@ -6779,9 +6785,9 @@ static __global__ void flash_attn_ext_f16(
// store result to shared memory (reuse sq) // store result to shared memory (reuse sq)
if (warp_id == 0) { if (warp_id == 0) {
for (int j = 0; j < Q16; ++j) { for (int j = 0; j < QPT; ++j) {
for (int i = 0; i < D16; ++i) { for (int i = 0; i < DPT; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync(sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major);
} }
} }
} }