MUSA adds support for __vsubss4

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
Xiaodong Ye 2024-07-23 19:37:38 +08:00
parent 27396eac46
commit f4a9303e49

View file

@ -259,7 +259,6 @@
#define cublasComputeType_t cudaDataType_t
// XXX: Clang builtins mapping
#define __vsubss4 __vsubss4_musa
#define __vsub4 __vsub4_musa
#define __vcmpeq4 __vcmpeq4_musa
#define __vcmpne4 __vcmpne4_musa
@ -372,30 +371,10 @@ typedef float2 dfloat2;
#define __has_builtin(x) 0
#endif
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int &>(c);
#else
int8x4_t c;
int16_t tmp;
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp = va[i] - vb[i];
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
c[i] = tmp;
}
return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
return __vsubss4_musa(a, b);
return __vsubss4(a, b);
}
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {