* [SYCL] Move to compile-time backend selection on oneMKL Interface for NVIDIA backend. Move to compile-time backend selection to avoid latency at run time. Add it to all MKL gemm calls, and only for the NVIDIA backend. Signed-off-by: nscipione <nicolo.scipione@codeplay.com> * Formatting * Address PR comments to increase readability --------- Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
56 lines
2 KiB
C++
56 lines
2 KiB
C++
#include <sycl/sycl.hpp>
|
|
#include <oneapi/mkl.hpp>
|
|
#include "outprod.hpp"
|
|
|
|
|
|
// SYCL implementation of the out_prod operator for F32 tensors.
//
// Submits a single oneMKL column-major GEMM on the context's stream with
// alpha = 1 and beta = 0:
//     A = src0, not transposed, leading dimension ne00
//     B = src1, transposed unless ggml_is_transposed(src1) is true
//     C = dst,  leading dimension ne0
//     m = ne0, n = ne1, k = ne01
//
// Preconditions, each enforced with GGML_ASSERT:
//   - src0, src1 and dst are all GGML_TYPE_F32
//   - src0 and dst are contiguous
//   - ne01 == ne11, ne0 == ne00, ne1 == ne10
//
// NOTE(review): only ne0 / ne1 / ne01 feed the GEMM; ne2 and ne3 are never
// consulted here, so this presumably relies on callers dispatching 2-D
// operands only -- confirm against the backend's supports-op check.
//
// ctx  : backend context that supplies the SYCL queue
// src0 : left GEMM operand
// src1 : right GEMM operand
// dst  : GEMM output
//
// A sycl::exception thrown by the GEMM call is printed to std::cerr and then
// escalated through GGML_ASSERT(false).
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
                           const ggml_tensor* src1, ggml_tensor* dst) {
    namespace mkl = oneapi::mkl;

    // Element type and layout requirements.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    // Brings the neXY / nbXY shape and stride locals used below into scope.
    GGML_TENSOR_BINARY_OP_LOCALS

    // Queue the GEMM will be submitted to.
    dpct::queue_ptr stream = ctx.stream();

    // Shape compatibility between the two operands and the output.
    GGML_ASSERT(ne01 == ne11);  // shared (reduced) dimension of src0 and src1
    GGML_ASSERT(ne0 == ne00);   // dst dim 0 equals src0 dim 0
    GGML_ASSERT(ne1 == ne10);   // dst dim 1 equals src1 dim 0

    // GEMM scalars: alpha = 1, beta = 0.
    const float scale_ab = 1.0f;
    const float scale_c  = 0.0f;

    // Raw float views of the tensor buffers.
    float*       out_buf = static_cast<float*>(dst->data);
    const float* rhs_buf = static_cast<const float*>(src1->data);
    const float* lhs_buf = static_cast<const float*>(src0->data);

    // GEMM extents, named after their BLAS roles.
    const auto rows  = ne0;   // m
    const auto cols  = ne1;   // n
    const auto depth = ne01;  // k

    // src1 is passed as `trans` in its regular layout; a src1 that ggml already
    // reports as transposed is passed as `nontrans`, and its leading dimension
    // is then taken from the dim-0 byte stride instead of the dim-1 byte stride.
    const bool           src1_is_transposed = ggml_is_transposed(src1);
    const mkl::transpose op_src0            = mkl::transpose::nontrans;
    const mkl::transpose op_src1 =
        src1_is_transposed ? mkl::transpose::nontrans : mkl::transpose::trans;
    const auto    src1_stride_bytes = src1_is_transposed ? nb10 : nb11;
    const int64_t ld_src1           = src1_stride_bytes / sizeof(float);

    try {
        // The two branches differ only in how the execution target is given:
        // a compile-time cuBLAS backend selector when GGML_SYCL_NVIDIA is
        // defined, the plain queue otherwise.
#ifdef GGML_SYCL_NVIDIA
        mkl::blas::column_major::gemm(mkl::backend_selector<mkl::backend::cublas>{ *stream },
                                      op_src0, op_src1, rows, cols, depth, scale_ab,
                                      lhs_buf, ne00, rhs_buf, ld_src1, scale_c, out_buf, ne0);
#else
        mkl::blas::column_major::gemm(*stream,
                                      op_src0, op_src1, rows, cols, depth, scale_ab,
                                      lhs_buf, ne00, rhs_buf, ld_src1, scale_c, out_buf, ne0);
#endif
    }
    catch (sycl::exception const& exc) {
        // Report the SYCL failure, then fail the assertion.
        std::cerr << exc.what() << std::endl;
        GGML_ASSERT(false);
    }
}
|