Fix gemm_batch_impl
This commit is contained in:
parent
0d4177126b
commit
9d05e6a0aa
1 changed file with 53 additions and 35 deletions
|
@ -1754,45 +1754,63 @@ namespace dpct
|
||||||
const void **b, int ldb, const void *beta, void **c,
|
const void **b, int ldb, const void *beta, void **c,
|
||||||
int ldc, int batch_size)
|
int ldc, int batch_size)
|
||||||
{
|
{
|
||||||
struct matrix_info_t
|
// TODO: fix issue where alpha could be a device pointer
|
||||||
{
|
|
||||||
oneapi::mkl::transpose transpose_info[2];
|
|
||||||
Ts value_info[2];
|
|
||||||
std::int64_t size_info[3];
|
|
||||||
std::int64_t ld_info[3];
|
|
||||||
std::int64_t groupsize_info;
|
|
||||||
};
|
|
||||||
|
|
||||||
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
// NOTE: could unify into single memory alloc if we wanted to improve
|
||||||
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
// performance. Structured a_trans, b_trans
|
||||||
|
oneapi::mkl::transpose *trans =
|
||||||
|
reinterpret_cast<oneapi::mkl::transpose *>(
|
||||||
|
std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size));
|
||||||
|
for (int batch = 0; batch < batch_size; ++batch) {
|
||||||
|
trans[batch + batch_size * 0] = a_trans;
|
||||||
|
trans[batch + batch_size * 1] = b_trans;
|
||||||
|
}
|
||||||
|
|
||||||
matrix_info_t *matrix_info =
|
// structured m, n, k, lda, ldb, ldc, group_size
|
||||||
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
|
int64_t *dims = reinterpret_cast<int64_t *>(
|
||||||
matrix_info->transpose_info[0] = a_trans;
|
std::malloc(sizeof(int64_t) * 7 * batch_size));
|
||||||
matrix_info->transpose_info[1] = b_trans;
|
for (int batch = 0; batch < batch_size; ++batch) {
|
||||||
matrix_info->value_info[0] = alpha_value;
|
dims[batch + batch_size * 0] = m;
|
||||||
matrix_info->value_info[1] = beta_value;
|
dims[batch + batch_size * 1] = n;
|
||||||
matrix_info->size_info[0] = m;
|
dims[batch + batch_size * 2] = k;
|
||||||
matrix_info->size_info[1] = n;
|
|
||||||
matrix_info->size_info[2] = k;
|
|
||||||
matrix_info->ld_info[0] = lda;
|
|
||||||
matrix_info->ld_info[1] = ldb;
|
|
||||||
matrix_info->ld_info[2] = ldc;
|
|
||||||
matrix_info->groupsize_info = batch_size;
|
|
||||||
|
|
||||||
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
dims[batch + batch_size * 3] = lda;
|
||||||
q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
dims[batch + batch_size * 4] = ldb;
|
||||||
matrix_info->size_info, matrix_info->size_info + 1,
|
dims[batch + batch_size * 5] = ldc;
|
||||||
matrix_info->size_info + 2, matrix_info->value_info,
|
|
||||||
reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
|
||||||
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
||||||
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
|
|
||||||
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
||||||
|
|
||||||
q.submit([&](sycl::handler &cgh)
|
dims[batch + batch_size * 6] =
|
||||||
{
|
1; // all groups in batch are 1 matrix
|
||||||
cgh.depends_on(e);
|
}
|
||||||
cgh.host_task([=] { std::free(matrix_info); }); });
|
|
||||||
|
// structured alpha, beta
|
||||||
|
Ts *coeff =
|
||||||
|
reinterpret_cast<Ts *>(std::malloc(sizeof(Ts) * 2 * batch_size));
|
||||||
|
for (int batch = 0; batch < batch_size; ++batch) {
|
||||||
|
coeff[batch + batch_size * 0] =
|
||||||
|
*reinterpret_cast<const Ts *>(alpha);
|
||||||
|
coeff[batch + batch_size * 1] = *reinterpret_cast<const Ts *>(beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
||||||
|
q, trans + batch_size * 0 /*a_trans*/,
|
||||||
|
trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/,
|
||||||
|
dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/,
|
||||||
|
coeff + batch_size * 0 /*alpha*/,
|
||||||
|
reinterpret_cast<const Ta **>(a), dims + batch_size * 3 /*lda*/,
|
||||||
|
reinterpret_cast<const Tb **>(b), dims + batch_size * 4 /*ldb*/,
|
||||||
|
coeff + batch_size * 1 /*beta*/, reinterpret_cast<Tc **>(c),
|
||||||
|
dims + batch_size * 5 /*ldc*/, batch_size,
|
||||||
|
dims + batch_size * 6 /*group_size*/);
|
||||||
|
q.wait_and_throw();
|
||||||
|
|
||||||
|
q.submit([&](sycl::handler &cgh) {
|
||||||
|
cgh.depends_on(e);
|
||||||
|
cgh.host_task([=] {
|
||||||
|
std::free(trans);
|
||||||
|
std::free(dims);
|
||||||
|
std::free(coeff);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Ta, class Tb, class Tc, class Ts>
|
template <class Ta, class Tb, class Tc, class Ts>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue