Add OpenBLAS support for sgemm() in ggml_compute_forward_out_prod_f32()

This commit is contained in:
gwjr 2023-11-16 18:45:33 +00:00
parent e5c1f02645
commit da122af024

8
ggml.c
View file

@ -9631,9 +9631,9 @@ static void ggml_compute_forward_out_prod_f32(
// compute by src0 rows // compute by src0 rows
// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
// TODO: #if defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) // TODO: #if defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
bool use_blas = ggml_is_matrix(src0) && bool use_blas = ggml_is_matrix(src0) &&
ggml_is_matrix(src1) && ggml_is_matrix(src1) &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src0) &&
@ -9641,7 +9641,7 @@ static void ggml_compute_forward_out_prod_f32(
#endif #endif
if (params->type == GGML_TASK_INIT) { if (params->type == GGML_TASK_INIT) {
#if defined(GGML_USE_ACCELERATE) // gemm beta will zero dst #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
if (use_blas) { if (use_blas) {
return; return;
} }
@ -9654,7 +9654,7 @@ static void ggml_compute_forward_out_prod_f32(
return; return;
} }
#if defined(GGML_USE_ACCELERATE) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (use_blas) { if (use_blas) {
if (params->ith != 0) { // All threads other than the first do no work. if (params->ith != 0) { // All threads other than the first do no work.
return; return;