SYCL gemm.hpp: use const cast to properly support dnnl::memory

This commit is contained in:
Akarshan Biswas 2024-12-11 11:27:13 +05:30
parent 274842d976
commit b0e27ad9ec
No known key found for this signature in database
GPG key ID: 52A578A14B32134D

View file

@ -52,8 +52,8 @@ public:
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
@ -80,8 +80,8 @@ public:
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);