no vec for hs, no hs==256 ncols==32 for Volta

commit 81da919864, parent d59ac670bf
2 changed files with 37 additions and 36 deletions
@@ -141,6 +141,7 @@
 #define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
+#define CC_AMPERE 800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
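Note (not part of the diff): CC_AMPERE joins the existing compute-capability constants so that dispatch code can gate kernels on Ampere and newer, the same way CC_VOLTA is already used. A minimal sketch of such a check, assuming the NVIDIA/AMD split encoded by CC_OFFSET_AMD above; the helper name is hypothetical:

// Illustration only -- not added by this commit.
static bool cc_is_nvidia_ampere_or_newer(const int cc) {
    // AMD compute capabilities are offset by CC_OFFSET_AMD, so exclude them first.
    return cc < CC_OFFSET_AMD && cc >= CC_AMPERE;
}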
@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) {
+    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
         const int nwarps = Q->ne[0] / WARP_SIZE;
         const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int shmem = 0;
         switch (Q->ne[0]) {
-            case 64:
-                flash_attn_vec_ext_f16<64>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 64:
+            //     flash_attn_vec_ext_f16<64>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 80:
             //     flash_attn_vec_ext_f16<80>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
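Note (not part of the diff): with the extra Q->ne[0] >= 128 condition, the single-column vector kernel is only taken for head sizes of at least 128 (and still only for a warp-multiple head size with a single query column); the 64- and 96-wide instantiations are commented out accordingly. A hedged sketch of the resulting predicate, with a hypothetical helper name and WARP_SIZE assumed to be 32:

// Illustration only; mirrors the guard in the hunk above, not code from the commit.
static bool use_flash_attn_vec_f16(const ggml_tensor * Q) {
    const int64_t head_size = Q->ne[0]; // "hs" in the commit title
    const int64_t n_cols    = Q->ne[1]; // number of query columns (1 during token generation)
    return head_size % WARP_SIZE == 0 && head_size >= 128 && n_cols == 1;
}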
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
             //         );
             //     break;
-            case 96:
-                flash_attn_vec_ext_f16<96>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 96:
+            //     flash_attn_vec_ext_f16<96>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 112:
             //     flash_attn_vec_ext_f16<112>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (Q->ne[0] % 32 == 0) {
         if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
             cols_per_block = 64;
-        } else if (Q->ne[1] >= 64) {
+        } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
            cols_per_block = 32;
         } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
             cols_per_block = 16;
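Note (not part of the diff): the added clause keeps 32 columns per block off the table for head sizes above 128 on pre-Ampere GPUs, which is the "no hs==256 ncols==32 for Volta" half of the commit title; such cases now fall through to 16 columns. A hedged, standalone restatement of the branches visible in this hunk (helper name hypothetical, branches outside the hunk omitted):

// Illustration only; restates the selection logic from the hunk above.
static int choose_cols_per_block(const int64_t hs, const int64_t n_cols, const int cc) {
    if (n_cols >= 128 && hs <= 128) {
        return 64;
    }
    if (n_cols >= 64 && (hs <= 128 || cc >= CC_AMPERE)) {
        return 32; // e.g. hs == 256 now requires Ampere or newer to get 32 columns
    }
    if (n_cols >= 32 || hs % 32 != 0) {
        return 16;
    }
    return 8; // fallback for cases outside this hunk -- assumed, not taken from the diff
}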