wip
This commit is contained in:
parent
1846e92a90
commit
49a483e0f2
1 changed file with 12 additions and 9 deletions
21
ggml-cuda.cu
21
ggml-cuda.cu
|
@ -6565,7 +6565,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
// loop over the KV cache
|
||||
// each warp handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
|
||||
const int ic = ic0 + warp_id*C;
|
||||
const int ic = ic0 + warp_id*16;
|
||||
if (ic >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
||||
}
|
||||
|
||||
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
const half * pk = (const half *) ((const char *) k + ((ic + 16*num_warps*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
half16x16_bT mk; // transposed key
|
||||
|
@ -6596,7 +6596,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
half16x16_acc mm;
|
||||
|
||||
if (mp) {
|
||||
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
|
||||
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*num_warps*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
|
||||
// convert accumulator to matrix_a
|
||||
|
@ -6686,22 +6686,25 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
for (int cc = 0; cc < C16; ++cc) {
|
||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
half16x16_b mv[D16];
|
||||
half16x16_b mv[C16][D16];
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
|
||||
for (int cc = 0; cc < C16; ++cc) {
|
||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*num_warps*cc)*256 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
nvcuda::wmma::load_matrix_sync(mv[cc][i], pv + i*16, 256/sizeof(half));
|
||||
}
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < C16; ++cc) {
|
||||
half16x16_a ms[Q16];
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
|
||||
}
|
||||
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[cc][i], lo[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue