Fix gemm_batch_impl
This commit is contained in:
parent
0d4177126b
commit
9d05e6a0aa
1 changed files with 53 additions and 35 deletions
|
@ -1754,45 +1754,63 @@ namespace dpct
|
|||
const void **b, int ldb, const void *beta, void **c,
|
||||
int ldc, int batch_size)
|
||||
{
|
||||
struct matrix_info_t
|
||||
{
|
||||
oneapi::mkl::transpose transpose_info[2];
|
||||
Ts value_info[2];
|
||||
std::int64_t size_info[3];
|
||||
std::int64_t ld_info[3];
|
||||
std::int64_t groupsize_info;
|
||||
};
|
||||
// TODO: fix issue where alpha could be a device pointer
|
||||
|
||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
||||
// NOTE: could unify into single memory alloc if we wanted to improve
|
||||
// performance structured a_trans, b_trans
|
||||
oneapi::mkl::transpose *trans =
|
||||
reinterpret_cast<oneapi::mkl::transpose *>(
|
||||
std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size));
|
||||
for (int batch = 0; batch < batch_size; ++batch) {
|
||||
trans[batch + batch_size * 0] = a_trans;
|
||||
trans[batch + batch_size * 1] = b_trans;
|
||||
}
|
||||
|
||||
matrix_info_t *matrix_info =
|
||||
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
|
||||
matrix_info->transpose_info[0] = a_trans;
|
||||
matrix_info->transpose_info[1] = b_trans;
|
||||
matrix_info->value_info[0] = alpha_value;
|
||||
matrix_info->value_info[1] = beta_value;
|
||||
matrix_info->size_info[0] = m;
|
||||
matrix_info->size_info[1] = n;
|
||||
matrix_info->size_info[2] = k;
|
||||
matrix_info->ld_info[0] = lda;
|
||||
matrix_info->ld_info[1] = ldb;
|
||||
matrix_info->ld_info[2] = ldc;
|
||||
matrix_info->groupsize_info = batch_size;
|
||||
// structured m, n, k, lda, ldb, ldc, group_size
|
||||
int64_t *dims = reinterpret_cast<int64_t *>(
|
||||
std::malloc(sizeof(int64_t) * 7 * batch_size));
|
||||
for (int batch = 0; batch < batch_size; ++batch) {
|
||||
dims[batch + batch_size * 0] = m;
|
||||
dims[batch + batch_size * 1] = n;
|
||||
dims[batch + batch_size * 2] = k;
|
||||
|
||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
||||
matrix_info->size_info, matrix_info->size_info + 1,
|
||||
matrix_info->size_info + 2, matrix_info->value_info,
|
||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
||||
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
||||
dims[batch + batch_size * 3] = lda;
|
||||
dims[batch + batch_size * 4] = ldb;
|
||||
dims[batch + batch_size * 5] = ldc;
|
||||
|
||||
q.submit([&](sycl::handler &cgh)
|
||||
{
|
||||
cgh.depends_on(e);
|
||||
cgh.host_task([=] { std::free(matrix_info); }); });
|
||||
dims[batch + batch_size * 6] =
|
||||
1; // all groups in batch are 1 matrix
|
||||
}
|
||||
|
||||
// structured alpha, beta
|
||||
Ts *coeff =
|
||||
reinterpret_cast<Ts *>(std::malloc(sizeof(Ts) * 2 * batch_size));
|
||||
for (int batch = 0; batch < batch_size; ++batch) {
|
||||
coeff[batch + batch_size * 0] =
|
||||
*reinterpret_cast<const Ts *>(alpha);
|
||||
coeff[batch + batch_size * 1] = *reinterpret_cast<const Ts *>(beta);
|
||||
}
|
||||
|
||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||
q, trans + batch_size * 0 /*a_trans*/,
|
||||
trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/,
|
||||
dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/,
|
||||
coeff + batch_size * 0 /*alpha*/,
|
||||
reinterpret_cast<const Ta **>(a), dims + batch_size * 3 /*lda*/,
|
||||
reinterpret_cast<const Tb **>(b), dims + batch_size * 4 /*ldb*/,
|
||||
coeff + batch_size * 1 /*beta*/, reinterpret_cast<Tc **>(c),
|
||||
dims + batch_size * 5 /*ldc*/, batch_size,
|
||||
dims + batch_size * 6 /*group_size*/);
|
||||
q.wait_and_throw();
|
||||
|
||||
q.submit([&](sycl::handler &cgh) {
|
||||
cgh.depends_on(e);
|
||||
cgh.host_task([=] {
|
||||
std::free(trans);
|
||||
std::free(dims);
|
||||
std::free(coeff);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <class Ta, class Tb, class Tc, class Ts>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue