diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index cd4b3a1e1..d2f1f921b 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -1754,45 +1754,63 @@ namespace dpct const void **b, int ldb, const void *beta, void **c, int ldc, int batch_size) { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; + // TODO: fix issue where alpha could be a device pointer - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + // NOTE: could unify into single memory alloc if we wanted to improve + // performance structured a_trans, b_trans + oneapi::mkl::transpose *trans = + reinterpret_cast( + std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size)); + for (int batch = 0; batch < batch_size; ++batch) { + trans[batch + batch_size * 0] = a_trans; + trans[batch + batch_size * 1] = b_trans; + } - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); - matrix_info->transpose_info[0] = a_trans; - matrix_info->transpose_info[1] = b_trans; - matrix_info->value_info[0] = alpha_value; - matrix_info->value_info[1] = beta_value; - matrix_info->size_info[0] = m; - matrix_info->size_info[1] = n; - matrix_info->size_info[2] = k; - matrix_info->ld_info[0] = lda; - matrix_info->ld_info[1] = ldb; - matrix_info->ld_info[2] = ldc; - matrix_info->groupsize_info = batch_size; + // structured m, n, k, lda, ldb, ldc, group_size + int64_t *dims = reinterpret_cast( + std::malloc(sizeof(int64_t) * 7 * batch_size)); + for (int batch = 0; batch < batch_size; ++batch) { + dims[batch + batch_size * 0] = m; + dims[batch + batch_size * 1] = n; + dims[batch + batch_size * 2] = k; - sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + dims[batch + batch_size * 3] = lda; + dims[batch + batch_size * 4] = ldb; + dims[batch + batch_size * 5] = ldc; - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); + dims[batch + batch_size * 6] = + 1; // all groups in batch are 1 matrix + } + + // structured alpha, beta + Ts *coeff = + reinterpret_cast(std::malloc(sizeof(Ts) * 2 * batch_size)); + for (int batch = 0; batch < batch_size; ++batch) { + coeff[batch + batch_size * 0] = + *reinterpret_cast(alpha); + coeff[batch + batch_size * 1] = *reinterpret_cast(beta); + } + + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, trans + batch_size * 0 /*a_trans*/, + trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/, + dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/, + coeff + batch_size * 0 /*alpha*/, + reinterpret_cast(a), dims + batch_size * 3 /*lda*/, + reinterpret_cast(b), dims + batch_size * 4 /*ldb*/, + coeff + batch_size * 1 /*beta*/, reinterpret_cast(c), + dims + batch_size * 5 /*ldc*/, batch_size, + dims + batch_size * 6 /*group_size*/); + q.wait_and_throw(); + + q.submit([&](sycl::handler &cgh) { + cgh.depends_on(e); + cgh.host_task([=] { + std::free(trans); + std::free(dims); + std::free(coeff); + }); + }); } template