tests : update dims

parent db1f3c482e
commit 12eaa22628

2 changed files with 110 additions and 76 deletions

ggml-cuda.cu | 74
@@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16(
                 for (int64_t j = 0; j < Q16; ++j) {
                     half16x16_a   mqka;
                     half16x16_acc mm;
-                    if(mp) {
+
+                    if (mp) {
                         nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
                     }
 
@@ -10927,25 +10928,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
 
-    switch (Q->ne[0])
-    {
-        case 16:
-            flash_attn_ext_f16<16, NQPB, NCPW>
-                <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) src0_extra->data_device[g_main_device], // Query
-                        (const char *) src1_extra->data_device[g_main_device], // Key
-                        (const char *) src2_extra->data_device[g_main_device], // Value
-                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
-                        (float *) dst_extra->data_device[g_main_device], // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                        );
-            break;
+    switch (Q->ne[0]) {
         case 64:
             flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -10980,6 +10963,40 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                         );
             break;
+        case 96:
+            flash_attn_ext_f16<96, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        case 112:
+            flash_attn_ext_f16<112, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
         case 128:
             flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -10997,6 +11014,23 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
                         );
             break;
+        case 256:
+            flash_attn_ext_f16<256, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
         default:
             break;
     }
@@ -572,7 +572,7 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
-#if 1
+#if 0
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
@@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
 
-#if 0
-    for (int hs : { 64, 80, 96, 112, 128, 256, }) {
+#if 1
+    for (int hs : { 64, 80, 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
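The only arithmetic the patch relies on is the shared-memory size passed to every kernel launch, shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2), which grows with the head size Q->ne[0] that the new 96/112/256 cases dispatch on. The standalone C++ sketch below (not part of the commit) just evaluates that same expression for the head sizes that appear in the patch; the nqpb, ncpw, and nwarps values are illustrative assumptions, since the diff does not show how they are derived.

// Minimal sketch, assuming example launch parameters: evaluates the shared-memory
// formula from ggml_cuda_flash_attn_ext for the head sizes touched by this commit.
#include <cstdio>
#include <cstddef>
#include <initializer_list>

int main() {
    const int nqpb   = 16; // queries per block   (assumed value, NQPB in the diff)
    const int ncpw   = 32; // KV columns per warp (assumed value, NCPW in the diff)
    const int nwarps = 4;  // warps per block     (assumed value)

    for (int hd : {64, 80, 96, 112, 128, 256}) {
        // same expression as in the diff; sizeof(float)/2 == sizeof(half)
        const size_t shmem = nqpb*(hd + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
        std::printf("head_dim=%3d -> shmem=%zu bytes\n", hd, shmem);
    }
    return 0;
}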