Use cblas_sgemm() to implement ggml_compute_forward_out_prod()
This commit is contained in:
parent
d75eae6333
commit
2f0c5dcaf5
1 changed files with 80 additions and 1 deletions
81
ggml.c
81
ggml.c
|
@ -9598,6 +9598,35 @@ static void ggml_compute_forward_mul_mat(
|
||||||
|
|
||||||
// ggml_compute_forward_out_prod
|
// ggml_compute_forward_out_prod
|
||||||
|
|
||||||
|
#if defined(GGML_USE_ACCELERATE)
|
||||||
|
// helper function to determine if it is better to use BLAS or not
|
||||||
|
// based on ggml_compute_forward_mul_mat_use_blas()
|
||||||
|
// However, testing suggested that BLAS was never slower than the existing code
|
||||||
|
static bool ggml_compute_forward_out_prod_use_blas(
|
||||||
|
const struct ggml_tensor * src0,
|
||||||
|
const struct ggml_tensor * src1,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
UNUSED(dst);
|
||||||
|
// const int64_t ne10 = src1->ne[0];
|
||||||
|
//
|
||||||
|
// const int64_t ne0 = dst->ne[0];
|
||||||
|
// const int64_t ne1 = dst->ne[1];
|
||||||
|
|
||||||
|
if (ggml_is_matrix(src0) &&
|
||||||
|
ggml_is_matrix(src1) &&
|
||||||
|
ggml_is_contiguous(src0) &&
|
||||||
|
(ggml_is_contiguous(src1) || ggml_is_transposed(src1))){ //&&
|
||||||
|
//(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// if (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
|
||||||
|
// printf("Cannot use BLAS for large matrix at %s; ne0: %lld, ne1: %lld, ne10:, %lld", dst->name, ne0, ne1, ne10);
|
||||||
|
// }
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
static void ggml_compute_forward_out_prod_f32(
|
static void ggml_compute_forward_out_prod_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
|
@ -9634,13 +9663,63 @@ static void ggml_compute_forward_out_prod_f32(
|
||||||
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT) {
|
if (params->type == GGML_TASK_INIT) {
|
||||||
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
|
#if !defined(GGML_USE_ACCELERATE) // gemm beta will do this
|
||||||
|
if (!ggml_compute_forward_out_prod_use_blas(src0, src1, dst)) {
|
||||||
|
#endif
|
||||||
|
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
|
||||||
|
#if !defined(GGML_USE_ACCELERATE)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params->type == GGML_TASK_FINALIZE) {
|
if (params->type == GGML_TASK_FINALIZE) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(GGML_USE_ACCELERATE)
|
||||||
|
if (ggml_compute_forward_out_prod_use_blas(src0, src1, dst)) {
|
||||||
|
if (params->ith != 0) { // All threads other than the first do no work.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
|
||||||
|
// src0: (k,n)
|
||||||
|
// src1: (k,m)
|
||||||
|
// dst: (m,n)
|
||||||
|
//
|
||||||
|
// Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
|
||||||
|
// Also expressed as (major,minor)
|
||||||
|
// a: (m,k): so src1 transposed
|
||||||
|
// b: (k,n): so src0
|
||||||
|
// c: (m,n)
|
||||||
|
//
|
||||||
|
// However, if ggml_is_transposed(src1) is true, then
|
||||||
|
// src1->data already contains a transposed version, so sgemm mustn't
|
||||||
|
// transpose it further.
|
||||||
|
|
||||||
|
int n = src0->ne[0];
|
||||||
|
int k = src0->ne[1];
|
||||||
|
int m = src1->ne[0];
|
||||||
|
|
||||||
|
int transposeA, lda;
|
||||||
|
|
||||||
|
if (!ggml_is_transposed(src1)) {
|
||||||
|
transposeA = CblasTrans;
|
||||||
|
lda = m;
|
||||||
|
} else {
|
||||||
|
transposeA = CblasNoTrans;
|
||||||
|
lda = k;
|
||||||
|
}
|
||||||
|
|
||||||
|
float * a = (float *) ((char *) src1->data);
|
||||||
|
float * b = (float *) ((char *) src0->data);
|
||||||
|
float * c = (float *) ((char *) dst->data);
|
||||||
|
|
||||||
|
cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// dst[:,:,:,:] = 0
|
// dst[:,:,:,:] = 0
|
||||||
// for i2,i3:
|
// for i2,i3:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue