diff --git a/ggml-blas.cpp b/ggml-blas.cpp index ade10b9ac..3d146fc01 100644 --- a/ggml-blas.cpp +++ b/ggml-blas.cpp @@ -10,6 +10,9 @@ # include #else # include +# ifdef BLIS_ENABLE_CBLAS +# include +# endif #endif struct ggml_backend_blas_context { @@ -128,6 +131,15 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg #endif } + +#if defined(OPENBLAS_VERSION) + openblas_set_num_threads(ctx->n_threads); +#endif + +#if defined(BLIS_ENABLE_CBLAS) + bli_thread_set_num_threads(ctx->n_threads); +#endif + for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { const int64_t i03 = i13/r3; @@ -322,6 +334,16 @@ ggml_backend_t ggml_backend_blas_init(void) { /* .context = */ ctx, }; +#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) + if (openblas_get_parallel() != OPENBLAS_OPENMP) { + fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); + } +#endif + +#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP) + fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__); +#endif + return backend; }