no vec for hs, no hs==256 ncols==32 for Volta
parent d59ac670bf
commit 81da919864

2 changed files with 37 additions and 36 deletions
@@ -141,6 +141,7 @@
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
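For context, the CC_* macros encode an NVIDIA compute capability as 100*major + 10*minor (700 = Volta, 800 = Ampere). The following is a small standalone sketch, separate from the patch itself, of how that same number can be obtained with the CUDA runtime API; the program and variable names are illustrative only.

#include <cstdio>
#include <cuda_runtime.h>

#define CC_VOLTA  700
#define CC_AMPERE 800

int main() {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        fprintf(stderr, "no CUDA device available\n");
        return 1;
    }
    // Same encoding as the CC_* macros above: 100*major + 10*minor.
    const int cc = 100*prop.major + 10*prop.minor;
    printf("compute capability %d -> Ampere or newer: %s\n", cc, cc >= CC_AMPERE ? "yes" : "no");
    return 0;
}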
@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) {
+    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
         const int nwarps = Q->ne[0] / WARP_SIZE;
         const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int shmem = 0;
         switch (Q->ne[0]) {
-            case 64:
-                flash_attn_vec_ext_f16<64>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 64:
+            //     flash_attn_vec_ext_f16<64>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 80:
             //     flash_attn_vec_ext_f16<80>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
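The first change in the hunk above tightens when the single-column "vec" kernel is even considered: after this commit the head size must additionally be at least 128, which is why the case 64 launch is commented out. A minimal standalone sketch of that predicate (the helper name and parameters are made up for illustration, not taken from the source):

#include <cstdio>

#define WARP_SIZE 32

// Mirrors the new gate: head size a multiple of WARP_SIZE, at least 128,
// and exactly one query column (Q->ne[1] == 1).
static bool use_vec_kernel(int head_size, int n_query_cols) {
    return head_size % WARP_SIZE == 0 && head_size >= 128 && n_query_cols == 1;
}

int main() {
    printf("hs= 64, cols=1: %d\n", use_vec_kernel( 64, 1)); // 0 -> falls through to the tile kernels
    printf("hs=128, cols=1: %d\n", use_vec_kernel(128, 1)); // 1
    printf("hs=128, cols=8: %d\n", use_vec_kernel(128, 8)); // 0
    return 0;
}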
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
             //         );
             //     break;
-            case 96:
-                flash_attn_vec_ext_f16<96>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 96:
+            //     flash_attn_vec_ext_f16<96>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 112:
             //     flash_attn_vec_ext_f16<112>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (Q->ne[0] % 32 == 0) {
         if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
             cols_per_block = 64;
-        } else if (Q->ne[1] >= 64) {
+        } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
             cols_per_block = 32;
         } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
             cols_per_block = 16;
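This last hunk is the "no hs==256 ncols==32 for Volta" part of the commit message: 32 query columns per block are now only chosen for head sizes above 128 when the device is Ampere or newer, so Volta falls through to the 16-column branch. A standalone sketch of the resulting selection, with an assumed fallback value and the compute capability passed as a plain parameter (the real code reads it via ggml_cuda_info().devices[ctx.device].cc):

#include <cstdio>

#define CC_VOLTA  700
#define CC_AMPERE 800

static int pick_cols_per_block(int hs /* Q->ne[0] */, int nq /* Q->ne[1] */, int cc) {
    int cols_per_block = 8; // assumed fallback; the surrounding branches are not shown in the hunk
    if (hs % 32 == 0) {
        if (nq >= 128 && hs <= 128) {
            cols_per_block = 64;
        } else if (nq >= 64 && (hs <= 128 || cc >= CC_AMPERE)) {
            cols_per_block = 32;
        } else if (nq >= 32 || hs % 32 != 0) {
            cols_per_block = 16;
        }
    }
    return cols_per_block;
}

int main() {
    printf("hs=256, nq=64 on Volta : %d cols\n", pick_cols_per_block(256, 64, CC_VOLTA));  // 16
    printf("hs=256, nq=64 on Ampere: %d cols\n", pick_cols_per_block(256, 64, CC_AMPERE)); // 32
    printf("hs=128, nq=64 on Volta : %d cols\n", pick_cols_per_block(128, 64, CC_VOLTA));  // 32
    return 0;
}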