diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7130209e7..2c050c0c4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16( for (int64_t j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; - if(mp) { + + if (mp) { nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); } @@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); - switch (Q->ne[0]) - { - case 16: - flash_attn_ext_f16<16, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) src0_extra->data_device[g_main_device], // Query - (const char *) src1_extra->data_device[g_main_device], // Key - (const char *) src2_extra->data_device[g_main_device], // Value - mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask - (float *) dst_extra->data_device[g_main_device], // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - default: - break; + switch (Q->ne[0]) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) src0_extra->data_device[g_main_device], // Query + (const char *) src1_extra->data_device[g_main_device], // Key + (const char *) src2_extra->data_device[g_main_device], // Value + mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask + (float *) dst_extra->data_device[g_main_device], // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 727f2ea06..9feb5e1fe 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -572,7 +572,7 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; -#if 1 +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } @@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); -#if 0 - for (int hs : { 64, 80, 96, 112, 128, 256, }) { +#if 1 + for (int hs : { 64, 80, 128, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {