Fix missing MMQ ops when on hipBLAS

I had put the ggml_cuda_supports_mmq call in the wrong place: it sat inside the CUDA-only branch, so hipBLAS builds never applied it.
Iwan Kawrakow 2024-01-08 10:09:50 +01:00
parent 47ae9b8f34
commit 61c04053a4


@@ -8894,10 +8894,10 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
 #endif // CUDA_USE_TENSOR_CORES
-    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
 
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);