Fix gemm_batch_impl

This commit is contained in:
Aidan 2024-02-15 11:11:19 +00:00
parent 0d4177126b
commit 9d05e6a0aa

View file

@ -1754,45 +1754,63 @@ namespace dpct
const void **b, int ldb, const void *beta, void **c, const void **b, int ldb, const void *beta, void **c,
int ldc, int batch_size) int ldc, int batch_size)
{ {
struct matrix_info_t // TODO: fix issue where alpha could be a device pointer
{
oneapi::mkl::transpose transpose_info[2];
Ts value_info[2];
std::int64_t size_info[3];
std::int64_t ld_info[3];
std::int64_t groupsize_info;
};
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); // NOTE: could unify into single memory alloc if we wanted to improve
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); // performance structured a_trans, b_trans
oneapi::mkl::transpose *trans =
reinterpret_cast<oneapi::mkl::transpose *>(
std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size));
for (int batch = 0; batch < batch_size; ++batch) {
trans[batch + batch_size * 0] = a_trans;
trans[batch + batch_size * 1] = b_trans;
}
matrix_info_t *matrix_info = // structured m, n, k, lda, ldb, ldc, group_size
(matrix_info_t *)std::malloc(sizeof(matrix_info_t)); int64_t *dims = reinterpret_cast<int64_t *>(
matrix_info->transpose_info[0] = a_trans; std::malloc(sizeof(int64_t) * 7 * batch_size));
matrix_info->transpose_info[1] = b_trans; for (int batch = 0; batch < batch_size; ++batch) {
matrix_info->value_info[0] = alpha_value; dims[batch + batch_size * 0] = m;
matrix_info->value_info[1] = beta_value; dims[batch + batch_size * 1] = n;
matrix_info->size_info[0] = m; dims[batch + batch_size * 2] = k;
matrix_info->size_info[1] = n;
matrix_info->size_info[2] = k;
matrix_info->ld_info[0] = lda;
matrix_info->ld_info[1] = ldb;
matrix_info->ld_info[2] = ldc;
matrix_info->groupsize_info = batch_size;
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( dims[batch + batch_size * 3] = lda;
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, dims[batch + batch_size * 4] = ldb;
matrix_info->size_info, matrix_info->size_info + 1, dims[batch + batch_size * 5] = ldc;
matrix_info->size_info + 2, matrix_info->value_info,
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
q.submit([&](sycl::handler &cgh) dims[batch + batch_size * 6] =
{ 1; // all groups in batch are 1 matrix
cgh.depends_on(e); }
cgh.host_task([=] { std::free(matrix_info); }); });
// structured alpha, beta
Ts *coeff =
reinterpret_cast<Ts *>(std::malloc(sizeof(Ts) * 2 * batch_size));
for (int batch = 0; batch < batch_size; ++batch) {
coeff[batch + batch_size * 0] =
*reinterpret_cast<const Ts *>(alpha);
coeff[batch + batch_size * 1] = *reinterpret_cast<const Ts *>(beta);
}
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
q, trans + batch_size * 0 /*a_trans*/,
trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/,
dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/,
coeff + batch_size * 0 /*alpha*/,
reinterpret_cast<const Ta **>(a), dims + batch_size * 3 /*lda*/,
reinterpret_cast<const Tb **>(b), dims + batch_size * 4 /*ldb*/,
coeff + batch_size * 1 /*beta*/, reinterpret_cast<Tc **>(c),
dims + batch_size * 5 /*ldc*/, batch_size,
dims + batch_size * 6 /*group_size*/);
q.wait_and_throw();
q.submit([&](sycl::handler &cgh) {
cgh.depends_on(e);
cgh.host_task([=] {
std::free(trans);
std::free(dims);
std::free(coeff);
});
});
} }
template <class Ta, class Tb, class Tc, class Ts> template <class Ta, class Tb, class Tc, class Ts>