diff --git a/ggml.c b/ggml.c index d59af30b3..67c791b27 100644 --- a/ggml.c +++ b/ggml.c @@ -16570,6 +16570,28 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { //n_tasks = MIN(n_threads, MAX(1, nr0/128)); //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); + +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#elif defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#elif defined(GGML_USE_SYCL) + if (ggml_sycl_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; + } +#endif +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif } break; case GGML_OP_MUL_MAT_ID: {