no vec for hs, no hs==256 ncols==32 for Volta

Author:    Johannes Gäßler, 2024-03-30 10:34:09 +01:00 (committed by Georgi Gerganov)
Parent:    d59ac670bf
Commit:    81da919864

2 changed files with 37 additions and 36 deletions

File 1 of 2:

@@ -141,6 +141,7 @@
 #define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
+#define CC_AMPERE 800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
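
Note (not part of the commit): the new CC_AMPERE constant follows the existing encoding, where compute capability X.Y is stored as the integer 100*X + 10*Y, so device checks reduce to plain integer comparisons. A minimal standalone sketch of such a comparison, with hypothetical helper and variable names:

// Illustrative sketch only, not part of this commit.
#include <cstdio>

#define CC_VOLTA  700 // compute capability 7.0
#define CC_AMPERE 800 // compute capability 8.0

// hypothetical helper: encode capability X.Y as 100*X + 10*Y, matching the CC_* constants
static int encode_cc(int major, int minor) {
    return 100 * major + 10 * minor;
}

int main() {
    const int cc = encode_cc(7, 5); // e.g. a compute capability 7.5 (Turing) device
    std::printf("at least Volta:  %d\n", cc >= CC_VOLTA);  // prints 1
    std::printf("at least Ampere: %d\n", cc >= CC_AMPERE); // prints 0
    return 0;
}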

File 2 of 2:

@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) {
+    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
         const int nwarps = Q->ne[0] / WARP_SIZE;
         const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int shmem = 0;
         switch (Q->ne[0]) {
-            case 64:
-                flash_attn_vec_ext_f16<64>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 64:
+            //     flash_attn_vec_ext_f16<64>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 80:
             //     flash_attn_vec_ext_f16<80>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
             //         );
             //     break;
-            case 96:
-                flash_attn_vec_ext_f16<96>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 96:
+            //     flash_attn_vec_ext_f16<96>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 112:
             //     flash_attn_vec_ext_f16<112>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
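
Note (not part of the commit): combined with the tightened `Q->ne[0] >= 128` condition at the top of the previous hunk, commenting out the 64 and 96 cases means the single-query ("vec") kernel is now only considered for head sizes of at least 128. A minimal standalone sketch of the resulting gate, with hypothetical names:

// Illustrative restatement of the dispatch condition after this commit; not part of the diff.
#include <cstdio>

static bool use_vec_kernel(int head_size, int n_queries) {
    const int WARP_SIZE = 32; // warp size assumed to be 32, as on NVIDIA GPUs
    return head_size % WARP_SIZE == 0 && head_size >= 128 && n_queries == 1;
}

int main() {
    std::printf("hs=64,  nq=1: %d\n", use_vec_kernel( 64, 1)); // 0 (this case was allowed before the commit)
    std::printf("hs=128, nq=1: %d\n", use_vec_kernel(128, 1)); // 1
    std::printf("hs=128, nq=4: %d\n", use_vec_kernel(128, 4)); // 0 (not a single-query case)
    return 0;
}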
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (Q->ne[0] % 32 == 0) {
         if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
             cols_per_block = 64;
-        } else if (Q->ne[1] >= 64) {
+        } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
             cols_per_block = 32;
         } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
             cols_per_block = 16;
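
Note (not part of the commit): this hunk is what the commit title's "no hs==256 ncols==32 for Volta" refers to; 32 columns per block are now only used for head sizes above 128 on Ampere or newer devices. A standalone sketch of the selection logic, with a hypothetical function name and an assumed fallback value:

// Illustrative restatement of the tile-width selection after this commit; not part of the diff.
// pick_cols_per_block is a hypothetical name; the fallback of 8 is an assumption (outside the hunk).
#include <cstdio>

static int pick_cols_per_block(int head_size, int n_queries, int cc) {
    const int CC_AMPERE = 800;
    if (head_size % 32 == 0) {
        if (n_queries >= 128 && head_size <= 128) {
            return 64;
        } else if (n_queries >= 64 && (head_size <= 128 || cc >= CC_AMPERE)) {
            return 32; // hs == 256 with ncols == 32 is now skipped on Volta
        } else if (n_queries >= 32 || head_size % 32 != 0) {
            return 16;
        }
    }
    return 8; // assumed fallback, not shown in this diff
}

int main() {
    std::printf("hs=256, nq=64 on Volta (700):  %d\n", pick_cols_per_block(256, 64, 700)); // 16 (was 32 before this commit)
    std::printf("hs=256, nq=64 on Ampere (800): %d\n", pick_cols_per_block(256, 64, 800)); // 32
    return 0;
}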