tests : update dims
This commit is contained in:
parent
db1f3c482e
commit
12eaa22628
2 changed files with 110 additions and 76 deletions
180
ggml-cuda.cu
180
ggml-cuda.cu
|
@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
half16x16_a mqka;
|
||||
half16x16_acc mm;
|
||||
if(mp) {
|
||||
|
||||
if (mp) {
|
||||
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
|
||||
|
@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
|
||||
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
|
||||
|
||||
switch (Q->ne[0])
|
||||
{
|
||||
case 16:
|
||||
flash_attn_ext_f16<16, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 80:
|
||||
flash_attn_ext_f16<80, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 128:
|
||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 80:
|
||||
flash_attn_ext_f16<80, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 96:
|
||||
flash_attn_ext_f16<96, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 112:
|
||||
flash_attn_ext_f16<112, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 128:
|
||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 256:
|
||||
flash_attn_ext_f16<256, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue