cuda : express strides with helper constants
This commit is contained in:
parent
1846e92a90
commit
a647257b47
1 changed files with 63 additions and 57 deletions
120
ggml-cuda.cu
120
ggml-cuda.cu
|
@ -6399,10 +6399,16 @@ static __global__ void flash_attn_f32(
|
||||||
}
|
}
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
|
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
|
// queries, dims, cache per fragment
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
|
#define QPF 16
|
||||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
|
#define DPF 16
|
||||||
|
#define CPF 16
|
||||||
|
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, QPF, CPF, DPF, half, nvcuda::wmma::row_major> half16x16_a;
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, QPF, CPF, DPF, half, nvcuda::wmma::row_major> half16x16_b;
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, QPF, CPF, DPF, half, nvcuda::wmma::col_major> half16x16_bT;
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, QPF, CPF, DPF, half> half16x16_acc;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// based on metal version
|
// based on metal version
|
||||||
|
@ -6443,9 +6449,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int iq2 = blockIdx.y;
|
const int iq2 = blockIdx.y;
|
||||||
const int iq1 = blockIdx.x * Q;
|
const int iq1 = blockIdx.x * Q;
|
||||||
|
|
||||||
const int D16 = D/16;
|
const int DPT = D/DPF; // dims per thread
|
||||||
const int Q16 = Q/16;
|
const int QPT = Q/QPF; // queries per thread
|
||||||
const int C16 = C/16;
|
const int CPT = C/CPF; // cache per thread
|
||||||
|
|
||||||
const int NW = WARP_SIZE;
|
const int NW = WARP_SIZE;
|
||||||
const int SH = (C + Q); // shared memory per simdgroup in (half)
|
const int SH = (C + Q); // shared memory per simdgroup in (half)
|
||||||
|
@ -6463,7 +6469,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
|
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
|
||||||
|
|
||||||
half16x16_acc zr;
|
half16x16_acc zr;
|
||||||
half16x16_acc lo[Q16][D16];
|
half16x16_acc lo[QPT][DPT];
|
||||||
|
|
||||||
// load heads from Q to shared memory
|
// load heads from Q to shared memory
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -6493,8 +6499,8 @@ static __global__ void flash_attn_ext_f16(
|
||||||
nvcuda::wmma::fill_fragment(zr, 0.0);
|
nvcuda::wmma::fill_fragment(zr, 0.0);
|
||||||
|
|
||||||
// zero out lo
|
// zero out lo
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6545,10 +6551,10 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int iv3 = iq3 / rv3;
|
const int iv3 = iq3 / rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
half16x16_a mq[Q16][D16];
|
half16x16_a mq[QPT][DPT];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + QPF*j*T + DPF*i, T);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6557,7 +6563,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
// prepare diagonal scale matrix
|
||||||
half16x16_b mscale;
|
half16x16_b mscale;
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < QPF; ++i) {
|
||||||
ss[i*T + i] = __float2half(scale);
|
ss[i*T + i] = __float2half(scale);
|
||||||
}
|
}
|
||||||
nvcuda::wmma::load_matrix_sync(mscale, ss, T);
|
nvcuda::wmma::load_matrix_sync(mscale, ss, T);
|
||||||
|
@ -6573,38 +6579,38 @@ static __global__ void flash_attn_ext_f16(
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int cc = 0; cc < C16; ++cc) {
|
for (int cc = 0; cc < CPT; ++cc) {
|
||||||
half16x16_acc mqk[Q16];
|
half16x16_acc mqk[QPT];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
const half * pk = (const half *) ((const char *) k + ((ic + CPF*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
half16x16_bT mk; // transposed key
|
half16x16_bT mk; // transposed key
|
||||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
nvcuda::wmma::load_matrix_sync(mk, pk + DPF*i, nb11/sizeof(half));
|
||||||
|
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mqk = mqk*scale + mask
|
// mqk = mqk*scale + mask
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
half16x16_a mqka;
|
half16x16_a mqka;
|
||||||
half16x16_acc mm;
|
half16x16_acc mm;
|
||||||
|
|
||||||
if (mp) {
|
if (mp) {
|
||||||
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::load_matrix_sync(mm, mp + QPF*j*(nb31/sizeof(half)) + ic + CPF*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert accumulator to matrix_a
|
// convert accumulator to matrix_a
|
||||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync( ss + QPF*j*T + CPF*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
||||||
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
|
nvcuda::wmma::load_matrix_sync (mqka, ss + QPF*j*T + CPF*cc, T);
|
||||||
|
|
||||||
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
|
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
|
||||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(ss + QPF*j*T + CPF*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6664,43 +6670,43 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
// O = diag(ms)*O
|
// O = diag(ms)*O
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
half16x16_a mm;
|
half16x16_a mm;
|
||||||
half16x16_b lob;
|
half16x16_b lob;
|
||||||
|
|
||||||
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
nvcuda::wmma::load_matrix_sync(mm, ss + QPF*j*T + C + QPF*j, T);
|
||||||
|
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
// convert accumulator to matrix_b
|
// convert accumulator to matrix_b
|
||||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync( ss + QPF*j*T + C + QPF*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
|
nvcuda::wmma::load_matrix_sync (lob, ss + QPF*j*T + C + QPF*j, T);
|
||||||
|
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
|
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// restore zeros
|
// restore zeros
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(ss + QPF*j*T + C + QPF*j, zr, T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
|
|
||||||
// O = O + (Q*K^T)*V
|
// O = O + (Q*K^T)*V
|
||||||
{
|
{
|
||||||
for (int cc = 0; cc < C16; ++cc) {
|
for (int cc = 0; cc < CPT; ++cc) {
|
||||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
const half * pv = (const half *) ((const char *) v + ((ic + CPF*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
half16x16_b mv[D16];
|
half16x16_b mv[DPT];
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
|
nvcuda::wmma::load_matrix_sync(mv[i], pv + DPF*i, nb21/sizeof(half));
|
||||||
}
|
}
|
||||||
|
|
||||||
half16x16_a ms[Q16];
|
half16x16_a ms[QPT];
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
|
nvcuda::wmma::load_matrix_sync(ms[j], ss + QPF*j*T + CPF*cc, T);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6721,9 +6727,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
// each simdgroup stores its output to shared memory, reusing sq
|
// each simdgroup stores its output to shared memory, reusing sq
|
||||||
if (warp_id == sg) {
|
if (warp_id == sg) {
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6754,22 +6760,22 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
half16x16_a ms0;
|
half16x16_a ms0;
|
||||||
half16x16_a ms1;
|
half16x16_a ms1;
|
||||||
half16x16_b t;
|
half16x16_b t;
|
||||||
half16x16_acc t2;
|
half16x16_acc t2;
|
||||||
|
|
||||||
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
|
nvcuda::wmma::load_matrix_sync(ms0, ss + QPF*j*T + C + QPF*j, T);
|
||||||
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
nvcuda::wmma::load_matrix_sync(ms1, ss + QPF*j*T + C + QPF*j + sg*SH, T);
|
||||||
|
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
nvcuda::wmma::load_matrix_sync(t, sq + QPF*j*T + DPF*i, T);
|
||||||
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
|
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
|
||||||
|
|
||||||
// convert accumulator to matrix_b
|
// convert accumulator to matrix_b
|
||||||
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync( sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
|
nvcuda::wmma::load_matrix_sync (t, sq + QPF*j*T + DPF*i, T);
|
||||||
|
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
|
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
|
||||||
}
|
}
|
||||||
|
@ -6779,9 +6785,9 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
// store result to shared memory (reuse sq)
|
// store result to shared memory (reuse sq)
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int j = 0; j < Q16; ++j) {
|
for (int j = 0; j < QPT; ++j) {
|
||||||
for (int i = 0; i < D16; ++i) {
|
for (int i = 0; i < DPT; ++i) {
|
||||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(sq + QPF*j*T + DPF*i, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue