llama.cpp/ggml/src/ggml-sycl/outprod.cpp
Nicolò Scipione 40c6d79fb5
SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend (#10584)
* [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend

Move to compile-time backend selection to avoid latency at run time.
Apply it to all oneMKL gemm calls, for the NVIDIA backend only.

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>

* Formatting

* Address PR comments to increase readability

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
2024-12-04 09:29:20 +08:00

56 lines
2 KiB
C++

#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>
#include "outprod.hpp"
// Computes the F32 outer product dst = src0 * src1^T with a single column-major
// oneMKL GEMM:
//   dst [ne0 x ne1] = src0 [ne00 x ne01] * op(src1),  op(src1) is [ne11 x ne10]
//
// Requirements (asserted):
//   - all three tensors are F32; src0 and dst are contiguous;
//   - src1 is either contiguous or a transposed view (ggml_is_transposed), i.e.
//     one of its two strides is a unit (sizeof(float)) stride;
//   - only the 2D case is handled: dims 2 and 3 of every tensor must be 1.
// Any exception raised while issuing the GEMM is printed and the process aborts.
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                           const ggml_tensor* src1, ggml_tensor* dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_TENSOR_BINARY_OP_LOCALS

    // Get SYCL queue
    dpct::queue_ptr stream = ctx.stream();

    // Dimension checks
    GGML_ASSERT(ne01 == ne11);  // Inner dimensions must match
    GGML_ASSERT(ne0 == ne00);   // Output rows match src0 rows
    GGML_ASSERT(ne1 == ne10);   // Output cols match src1 cols

    // A single GEMM call produces exactly one 2D slice; batched/broadcast inputs
    // would otherwise leave the remaining slices of dst unwritten.
    GGML_ASSERT(ne02 == 1 && ne03 == 1);
    GGML_ASSERT(ne12 == 1 && ne13 == 1);
    GGML_ASSERT(ne2 == 1 && ne3 == 1);

    // Get data pointers
    const float* src0_d = static_cast<const float*>(src0->data);
    const float* src1_d = static_cast<const float*>(src1->data);
    float* dst_d = static_cast<float*>(dst->data);

    // GEMM parameters
    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Describe src1 to GEMM as a column-major matrix B.
    //  - Normal layout: B is [ne10 x ne11] with column stride nb11, and the
    //    product needs op(B) = B^T.
    //  - Transposed view (nb10 > nb11): memory already holds [ne11 x ne10] with
    //    column stride nb10, so B is used as-is.
    const bool src1_T = ggml_is_transposed(src1);
    const oneapi::mkl::transpose src1_op =
        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;

    const int64_t b_rows        = src1_T ? ne11 : ne10;  // rows of B as stored
    const int64_t b_cols        = src1_T ? ne10 : ne11;  // columns of B as stored
    const size_t  b_elem_stride = src1_T ? nb11 : nb10;  // bytes between rows i and i+1
    const size_t  b_col_stride  = src1_T ? nb10 : nb11;  // bytes between columns

    // GEMM can only express B through a leading dimension, so elements inside a
    // column must be adjacent (irrelevant if a column holds a single element).
    GGML_ASSERT(b_elem_stride == sizeof(float) || b_rows == 1);

    int64_t ldb = static_cast<int64_t>(b_col_stride / sizeof(float));
    // With a single stored column the column stride is never used for addressing,
    // but BLAS still validates ldb >= rows. A transposed view of a 1-row tensor
    // has nb10 == nb11 (so it is not detected as transposed) and would give ldb == 1.
    if (b_cols == 1 && ldb < b_rows) {
        ldb = b_rows;
    }

    try {
        // Perform matrix multiplication using oneMKL GEMM
#ifdef GGML_SYCL_NVIDIA
        // Compile-time backend selection avoids run-time dispatch on NVIDIA.
        oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
                                              oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
                                              ne00, src1_d, ldb, beta, dst_d, ne0);
#else
        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
                                              src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
#endif
    }
    catch (std::exception const& exc) {
        // Both sycl::exception and oneMKL's exception hierarchy (e.g. invalid
        // argument / unsupported device) derive from std::exception.
        std::cerr << exc.what() << std::endl;
        GGML_ASSERT(false);
    }
}