no vec for hs, no hs==256 ncols==32 for Volta

Johannes Gäßler 2024-03-30 10:34:09 +01:00 committed by Georgi Gerganov
parent d59ac670bf
commit 81da919864
2 changed files with 37 additions and 36 deletions
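
In plain terms, the commit makes two changes, both visible in the hunks below: the f16 vector flash-attention kernel is now only dispatched for head sizes of at least 128 (the case 64 and case 96 launches are commented out and the dispatch condition gains Q->ne[0] >= 128), and for head sizes above 128 the 32-column tile is now only used on Ampere or newer GPUs, which is what the new CC_AMPERE constant gates. The sketch below restates those two gates; it is distilled from the diff rather than taken from the repository, and the helper names use_vec_kernel and allow_ncols32 are hypothetical.

#include <cstdint>

// Minimal sketch of the two dispatch gates after this commit (helper names are hypothetical).
// hs = head size (Q->ne[0]), n_cols = number of query columns (Q->ne[1]), cc = compute capability.
static bool use_vec_kernel(int64_t hs, int64_t n_cols) {
    // vector kernel: single query column only, and after this commit hs must be >= 128
    return hs % 32 == 0 && hs >= 128 && n_cols == 1;   // 32 == WARP_SIZE
}

static bool allow_ncols32(int64_t hs, int cc) {
    // cols_per_block == 32 with hs > 128 (e.g. hs == 256) is reserved for Ampere or newer;
    // on Volta such cases fall through to a narrower tile
    return hs <= 128 || cc >= 800;                      // 800 == CC_AMPERE
}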


@@ -141,6 +141,7 @@
 #define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
+#define CC_AMPERE 800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
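
For reference, these constants encode the CUDA compute capability times 100: CC_VOLTA = 700 is 7.0 and the new CC_AMPERE = 800 is 8.0. A hypothetical illustration of how the new constant is meant to be used (the actual call site is in the last hunk below):

// hypothetical usage sketch, not part of the diff
const int cc = ggml_cuda_info().devices[ctx.device].cc;   // e.g. 700 on a V100, 860 on an RTX 3090
const bool is_ampere_or_newer = cc >= CC_AMPERE;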


@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 float scale;
 memcpy(&scale, KQV->op_params, sizeof(float));
-if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) {
+if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
 const int nwarps = Q->ne[0] / WARP_SIZE;
 const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
 const dim3 block_dim(WARP_SIZE, nwarps, 1);
 const int shmem = 0;
 switch (Q->ne[0]) {
-case 64:
-flash_attn_vec_ext_f16<64>
-<<<blocks_num, block_dim, shmem, main_stream>>> (
-(const char *) Q->data, // Query
-(const char *) K->data, // Key
-(const char *) V->data, // Value
-mask ? ((const char *) mask->data) : nullptr, // Mask
-(float *) KQV->data, // dst
-scale,
-Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-Q->nb[1], Q->nb[2], Q->nb[3],
-K->nb[1], K->nb[2], K->nb[3],
-KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-);
-break;
+// case 64:
+// flash_attn_vec_ext_f16<64>
+// <<<blocks_num, block_dim, shmem, main_stream>>> (
+// (const char *) Q->data, // Query
+// (const char *) K->data, // Key
+// (const char *) V->data, // Value
+// mask ? ((const char *) mask->data) : nullptr, // Mask
+// (float *) KQV->data, // dst
+// scale,
+// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+// Q->nb[1], Q->nb[2], Q->nb[3],
+// K->nb[1], K->nb[2], K->nb[3],
+// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+// );
+// break;
 // case 80:
 // flash_attn_vec_ext_f16<80>
 // <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 // KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
 // );
 // break;
-case 96:
-flash_attn_vec_ext_f16<96>
-<<<blocks_num, block_dim, shmem, main_stream>>> (
-(const char *) Q->data, // Query
-(const char *) K->data, // Key
-(const char *) V->data, // Value
-mask ? ((const char *) mask->data) : nullptr, // Mask
-(float *) KQV->data, // dst
-scale,
-Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-Q->nb[1], Q->nb[2], Q->nb[3],
-K->nb[1], K->nb[2], K->nb[3],
-KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-);
-break;
+// case 96:
+// flash_attn_vec_ext_f16<96>
+// <<<blocks_num, block_dim, shmem, main_stream>>> (
+// (const char *) Q->data, // Query
+// (const char *) K->data, // Key
+// (const char *) V->data, // Value
+// mask ? ((const char *) mask->data) : nullptr, // Mask
+// (float *) KQV->data, // dst
+// scale,
+// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+// Q->nb[1], Q->nb[2], Q->nb[3],
+// K->nb[1], K->nb[2], K->nb[3],
+// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+// );
+// break;
 // case 112:
 // flash_attn_vec_ext_f16<112>
 // <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 if (Q->ne[0] % 32 == 0) {
 if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
 cols_per_block = 64;
-} else if (Q->ne[1] >= 64) {
+} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
 cols_per_block = 32;
 } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
 cols_per_block = 16;
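
Reading this last hunk as a whole: when Q->ne[1] >= 64 but the new parenthesised condition fails, i.e. a head size above 128 on a pre-Ampere device such as Volta, execution falls through to the next branch, whose Q->ne[1] >= 32 test still holds, so cols_per_block ends up as 16. Below is a sketch of the cascade assembled from the hunk's context lines, with the compute capability hoisted into a local for readability; the default value of cols_per_block and the surrounding control flow are outside the hunk and not shown.

// Consolidated view of the tile-width cascade after this change (context outside the hunk omitted).
const int cc = ggml_cuda_info().devices[ctx.device].cc;
if (Q->ne[0] % 32 == 0) {
    if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
        cols_per_block = 64;
    } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || cc >= CC_AMPERE)) {
        cols_per_block = 32;   // head sizes above 128 only reach this on Ampere or newer
    } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
        cols_per_block = 16;   // e.g. Volta with head size 256 and many queries lands here
    }
}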