diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c1b179c6b..03ecdee7c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -74,7 +74,7 @@ #include "ggml.h" #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define CC_TURING 700 +#define CC_TURING 1000000000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -90,24 +90,18 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) { static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(__gfx1010__)// || defined(__gfx900__) - int ashift; - int bshift; - int aext; - int bext; +#elif defined(__gfx1010__) || defined(__gfx900__) + int tmp1; + int tmp2; asm("\n \ - v_pk_ashrrev_i16 %1, 0x80008, %5 \n \ - v_pk_ashrrev_i16 %2, 0x80008, %6 \n \ - v_mov_b32_sdwa %3, sext(%5) dst_sel:WORD_1 src0_sel:BYTE_2 \n \ - v_mov_b32_sdwa %3, sext(%5) dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE src0_sel:BYTE_0 \n \ - v_mov_b32_sdwa %4, sext(%6) dst_sel:WORD_1 src0_sel:BYTE_2 \n \ - v_mov_b32_sdwa %4, sext(%6) dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE src0_sel:BYTE_0 \n \ - v_mad_i32_i16 %0, %1, %2, %0 op_sel:[0, 0, 0, 0] \n \ - v_mad_i32_i16 %0, %1, %2, %0 op_sel:[1, 1, 0, 0] \n \ - v_mad_i32_i16 %0, %3, %4, %0 op_sel:[0, 0, 0, 0] \n \ - v_mad_i32_i16 %0, %3, %4, %0 op_sel:[1, 1, 0, 0] \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ " - : "+v"(c), "=&v"(ashift), "=&v"(bshift), "=&v"(aext), "=&v"(bext) + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) : "v"(a), "v"(b) ); #else