From 77f88e350e10a7c0b5ded8ffeac7ae9504ad545e Mon Sep 17 00:00:00 2001
From: slaren
Date: Thu, 6 Jun 2024 01:40:43 +0200
Subject: [PATCH] add support for out_prod

---
 CMakeLists.txt |   2 +-
 ggml-blas.c    | 124 ++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 98 insertions(+), 28 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e4eaed070..6e5baa6a4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -92,8 +92,8 @@ endif()
 # 3rd party libs
 option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
 option(LLAMA_BLAS "llama: use BLAS" OFF)
-option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
+option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
 option(LLAMA_CUDA "llama: use CUDA" OFF)
 option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
 option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
diff --git a/ggml-blas.c b/ggml-blas.c
index 6d527c041..f826e7ab8 100644
--- a/ggml-blas.c
+++ b/ggml-blas.c
@@ -5,12 +5,10 @@
 
 #if defined(GGML_USE_ACCELERATE)
 #  include <Accelerate/Accelerate.h>
-#elif defined(GGML_USE_BLAS)
-#  if defined(GGML_BLAS_USE_MKL)
-#    include <mkl.h>
-#  else
-#    include <cblas.h>
-#  endif
+#elif defined(GGML_BLAS_USE_MKL)
+#  include <mkl.h>
+#else
+#  include <cblas.h>
 #endif
 
 struct ggml_backend_blas_context {
@@ -21,7 +19,7 @@ struct ggml_backend_blas_context {
 
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
-static bool ggml_compute_forward_mul_mat_use_blas(const struct ggml_tensor * dst) {
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
@@ -72,11 +70,8 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    // nb01 >= nb00 - src0 is not transposed
-    // compute by src0 rows
-
     const int64_t ne_plane = ne01*ne00;
-    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne13*ne12*ne_plane*sizeof(float);
+    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
 
     if (ctx->work_size < desired_wsize) {
         free(ctx->work_data);
@@ -87,21 +82,19 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     void * wdata = ctx->work_data;
 
     // convert src0 to float
-    if (true) {
-        if (type != GGML_TYPE_F32) {
-            ggml_to_float_t const to_float = type_traits.to_float;
+    if (type != GGML_TYPE_F32) {
+        ggml_to_float_t const to_float = type_traits.to_float;
 
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-                    float * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                float * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
 
 #ifdef GGML_USE_OPENMP
 #pragma omp parallel for num_threads(ctx->n_threads)
 #endif
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                    }
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
             }
         }
@@ -129,6 +122,70 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     }
 }
 
+static void ggml_backend_blas_out_prod(struct ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    // compute by src0 rows
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst: (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    int transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
 // backend interface
 
 GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
@@ -138,6 +195,9 @@ GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
 }
 
 GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+    struct ggml_backend_blas_context * ctx = (struct ggml_backend_blas_context *)backend->context;
+    free(ctx->work_data);
+    free(ctx);
     free(backend);
 }
 
@@ -158,8 +218,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
                 ggml_backend_blas_mul_mat(ctx, node);
                 break;
 
-            // TODO
-            //case GGML_OP_OUT_PROD:
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
 
             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
@@ -180,7 +241,16 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
 }
 
 GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return op->op == GGML_OP_MUL_MAT && ggml_compute_forward_mul_mat_use_blas(op);
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
+           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+            op->src[1]->type == GGML_TYPE_F32 &&
+            ggml_is_matrix(src0) &&
+            ggml_is_matrix(src1) &&
+            ggml_is_contiguous(src0) &&
+            (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
 
     GGML_UNUSED(backend);
 }
@@ -229,9 +299,9 @@ ggml_backend_t ggml_backend_blas_init(void) {
         return NULL;
     }
 
-    ctx->n_threads = GGML_DEFAULT_N_THREADS;
-    ctx->work_data = NULL;
-    ctx->work_size = 0;
+    ctx->n_threads  = GGML_DEFAULT_N_THREADS;
+    ctx->work_data  = NULL;
+    ctx->work_size  = 0;
 
     *backend = (struct ggml_backend) {
         /* .guid = */ ggml_backend_blas_guid(),
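
Note on the sgemm mapping (illustrative, not part of the patch): the comment block in
ggml_backend_blas_out_prod treats dst (m,n) as src1 transposed (m,k) times src0 (k,n) in
row-major order, with transposeA and lda switching on ggml_is_transposed(src1). The
standalone snippet below reproduces that exact cblas_sgemm argument order for the
non-transposed src1 case on a tiny input; the file name and build command are made up,
and any CBLAS implementation (e.g. OpenBLAS) should work.

    // sgemm_outprod_check.c -- hypothetical standalone check, not part of the tree
    // build: cc sgemm_outprod_check.c -lopenblas
    #include <cblas.h>
    #include <stdio.h>

    int main(void) {
        // src0: (k,n) = (2,3) row-major -> b, ldb = n
        // src1: (k,m) = (2,2) row-major -> a, used transposed, lda = m
        // dst : (m,n) = (2,3) row-major -> c, ldc = n
        const int n = 3, k = 2, m = 2;
        float src0[2*3] = { 1, 2, 3,
                            4, 5, 6 };
        float src1[2*2] = { 1, 0,
                            0, 2 };
        float dst[2*3]  = { 0 };

        // same argument order as the patch, with transposeA = CblasTrans, lda = m
        cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                    m, n, k, 1.0f, src1, m, src0, n, 0.0f, dst, n);

        // expected rows: {1, 2, 3} and {8, 10, 12}
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                printf("%6.1f", dst[i*n + j]);
            }
            printf("\n");
        }
        return 0;
    }

When ggml_is_transposed(src1) holds, the same call would use CblasNoTrans with lda = k,
matching the else branch in the patch.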