Using dp4a ptx intrinsics for an improved Mul8MAT perf [By Alcpz]
This commit is contained in:
parent
439b3fc75a
commit
eab4a88210
1 changed files with 15 additions and 0 deletions
|
@ -1834,6 +1834,20 @@ namespace dpct
|
||||||
template <typename T1, typename T2, typename T3>
|
template <typename T1, typename T2, typename T3>
|
||||||
inline auto dp4a(T1 a, T2 b, T3 c)
|
inline auto dp4a(T1 a, T2 b, T3 c)
|
||||||
{
|
{
|
||||||
|
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
|
||||||
|
defined(__SYCL_CUDA_ARCH__) && __SYCL_CUDA_ARCH__ >= 610
|
||||||
|
dot_product_acc_t<T1, T2> res;
|
||||||
|
if constexpr (std::is_same_v<dot_product_acc_t<T1, T2>, uint32_t>) {
|
||||||
|
asm volatile("dp4a.u32.u32 %0, %1, %2, %3;"
|
||||||
|
: "=r"(res)
|
||||||
|
: "r"(a), "r"(b), "r"(c));
|
||||||
|
} else {
|
||||||
|
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
||||||
|
: "=r"(res)
|
||||||
|
: "r"(a), "r"(b), "r"(c));
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
#else
|
||||||
dot_product_acc_t<T1, T2> res = c;
|
dot_product_acc_t<T1, T2> res = c;
|
||||||
auto va = extract_and_sign_or_zero_extend4(a);
|
auto va = extract_and_sign_or_zero_extend4(a);
|
||||||
auto vb = extract_and_sign_or_zero_extend4(b);
|
auto vb = extract_and_sign_or_zero_extend4(b);
|
||||||
|
@ -1842,6 +1856,7 @@ namespace dpct
|
||||||
res += va[2] * vb[2];
|
res += va[2] * vb[2];
|
||||||
res += va[3] * vb[3];
|
res += va[3] * vb[3];
|
||||||
return res;
|
return res;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
struct sub_sat
|
struct sub_sat
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue