fix kernel
This commit is contained in:
parent
3b0f74b428
commit
b1479dfbc5
2 changed files with 56 additions and 49 deletions
93
ggml-cuda.cu
93
ggml-cuda.cu
|
@ -6158,9 +6158,9 @@ static __global__ void flash_attn_f32(
|
|||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_bT;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
|
||||
|
||||
// based on metal version
|
||||
|
@ -6204,7 +6204,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int D16 = D/16;
|
||||
const int Q16 = Q/16;
|
||||
const int NW = WARP_SIZE;
|
||||
const int SH = (C + D); // shared memory per simdgroup in (half)
|
||||
const int SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const int T = D + num_warps*SH; // shared memory size per query in (half)
|
||||
const int T2 = T/2; // shared memory size per query in (half2)
|
||||
|
@ -6288,7 +6288,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||
|
||||
// pointer to the mask
|
||||
const float * mp = (const float *) (mask + (ir%ne31)*nb31);
|
||||
const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
|
@ -6305,7 +6305,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
half16x16_bT mk; // transposed key
|
||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
|
||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
||||
|
@ -6314,14 +6314,14 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// mqk = mqk*scale + mask
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
|
||||
int64_t msk_ne_row = nb31/sizeof(float);
|
||||
// const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
|
||||
// int64_t msk_ne_row = nb31/sizeof(float);
|
||||
for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
|
||||
int msk_col = i % 16;
|
||||
int msk_row = i / 16;
|
||||
mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
|
||||
// int msk_col = i % 16;
|
||||
// int msk_row = i / 16;
|
||||
mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
|
||||
}
|
||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
|
||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6370,11 +6370,11 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (lane_id == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
ss[j*T + C + j] = __float2half(ms);
|
||||
}
|
||||
|
||||
for (int64_t p = lane_id; p < C; p += NW) {
|
||||
const float s = ss[j*T + p];
|
||||
const float s = __half2float(ss[j*T + p]);
|
||||
|
||||
const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
|
||||
|
||||
|
@ -6393,14 +6393,18 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// O = diag(ms)*O
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
half16x16_a mm;
|
||||
half16x16_b zro;
|
||||
// half16x16_a mm;
|
||||
// half16x16_b zro;
|
||||
|
||||
nvcuda::wmma::fill_fragment(zro, 0.0);
|
||||
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
||||
// nvcuda::wmma::fill_fragment(zro, 0.0);
|
||||
// nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
|
||||
//nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
|
||||
for (uint32_t k = 0; k < 16*16; k++) {
|
||||
half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
|
||||
lo[j][i].x[k] = tmp * lo[j][i].x[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6444,7 +6448,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
if (warp_id == sg) {
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6487,13 +6491,13 @@ static __global__ void flash_attn_ext_f16(
|
|||
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::fill_fragment(t2, 0.0);
|
||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
||||
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
|
||||
|
||||
// t <- lo
|
||||
for (uint32_t k = 0; k < t.num_elements; k++) {
|
||||
t.x[k] = lo[j][i].x[k];
|
||||
}
|
||||
// store temporally 'lo' data
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
// load 'lo' data into t
|
||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
|
||||
}
|
||||
}
|
||||
|
@ -6504,22 +6508,20 @@ static __global__ void flash_attn_ext_f16(
|
|||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
// float2 * dst2 = (float2 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const float S = __half2float(ss[j*T + 0]);
|
||||
|
||||
for (int64_t i = lane_id; i < D2; i += NW) {
|
||||
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
|
||||
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
|
||||
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
|
||||
for (int64_t i = lane_id; i < D; i += NW) {
|
||||
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10526,13 +10528,17 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
if(mask) {
|
||||
GGML_ASSERT(mask->type == GGML_TYPE_F32);
|
||||
}
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
|
||||
if(mask) {
|
||||
GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
|
||||
}
|
||||
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
|
||||
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
|
@ -10541,7 +10547,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
|
||||
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
|
||||
ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
|
||||
ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
|
||||
|
||||
float scale;
|
||||
|
@ -10549,13 +10555,14 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
|
||||
const int nqpb = 16; // queries per block
|
||||
const int ncpw = 32; // cache values per warp (does not work for other values)
|
||||
const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
|
||||
// const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
|
||||
const int nwarps = 1;
|
||||
|
||||
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
||||
dim3 block_dim(32, nwarps, 1);
|
||||
|
||||
int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
|
||||
printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
|
||||
int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
|
||||
printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
|
||||
switch (Q->ne[0])
|
||||
{
|
||||
case 16:
|
||||
|
@ -10564,12 +10571,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask->ne[1], mask->nb[1],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
|
@ -10581,12 +10588,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask->ne[1], mask->nb[1],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
|
@ -10598,12 +10605,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask->ne[1], mask->nb[1],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
|
@ -10615,12 +10622,12 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask->ne[1], mask->nb[1],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
|
|
|
@ -201,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
|
|||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
if(!model.naive_attn) {
|
||||
struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
|
||||
struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, nullptr, 1.0f / sqrtf(model.q->ne[0]));
|
||||
ggml_build_forward_expand(gf, result);
|
||||
} else {
|
||||
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue