SYCL gemm.hpp: use const cast to properly support dnnl::memory
This commit is contained in:
parent
274842d976
commit
b0e27ad9ec
1 changed files with 4 additions and 4 deletions
|
@ -52,8 +52,8 @@ public:
|
||||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||||
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
||||||
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
||||||
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||||
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||||
|
|
||||||
|
@ -80,8 +80,8 @@ public:
|
||||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||||
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
|
||||||
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
|
||||||
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||||
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue