diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index fbfd15527..997f83e2a 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -15,6 +15,44 @@
 
 #include "common.cuh"
 
+
+#if CUDART_VERSION >= 12000
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    int ret = 0;
+
+#ifdef NEW_MMA_AVAILABLE
+    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+        : "+r"(ret) : "r"(x));
+#else
+    NO_DEVICE_CODE;
+#endif // defined(NEW_MMA_AVAILABLE)
+    return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing row-major matrix to column-major matrix.
+    const int src_i_low  = 2 * (threadIdx.x % 4);
+    const int src_i_high = src_i_low + 1;
+    const int src_j      = threadIdx.x / 4;
+
+    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
+    const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+    const int shift_low  = ((src_j + 0) % 2) * 16;
+    const int shift_high = ((src_j + 1) % 2) * 16;
+
+    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+    return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 12000
+
+
 template <typename T>
 struct mma_A_I16K4 {
     static_assert(sizeof(T) == 4, "bad type size");
@@ -119,21 +157,14 @@ struct mma_A_I16K8 {
     }
 
     __device__ __forceinline__ void transpose() {
-#ifdef NEW_MMA_AVAILABLE
         int * xi = (int *) x;
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(xi[0]) : "r"(xi[0]));
-        int tmp = 0;
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(tmp) : "r"(xi[1]));
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(xi[1]) : "r"(xi[2]));
+        xi[0] = ggml_cuda_movmatrix(xi[0]);
+
+        const int tmp = ggml_cuda_movmatrix(xi[1]);
+        xi[1] = ggml_cuda_movmatrix(xi[2]);
         xi[2] = tmp;
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(xi[3]) : "r"(xi[3]));
-#else
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+
+        xi[3] = ggml_cuda_movmatrix(xi[3]);
     }
 };
 
@@ -350,16 +381,10 @@ struct mma_C_I16J8 {
     __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
         mma_B_J8K8<half2> mma_B;
 
-#ifdef NEW_MMA_AVAILABLE
         int * xi  = (int *) x;
         int * Bxi = (int *) mma_B.x;
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(Bxi[0]) : "r"(xi[0]));
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(Bxi[1]) : "r"(xi[1]));
-#else
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+        Bxi[0] = ggml_cuda_movmatrix(xi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(xi[1]);
 
         return mma_B;
     }
@@ -417,15 +442,9 @@ struct mma_C_I16J8 {
         mma_B.x[0] = make_half2(x[0], x[1]);
         mma_B.x[1] = make_half2(x[2], x[3]);
 
-#ifdef NEW_MMA_AVAILABLE
         int * Bxi = (int *) mma_B.x;
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
-            : "+r"(Bxi[0]) : );
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
-            : "+r"(Bxi[1]) : );
-#else
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+        Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
 
         return mma_B;
    }
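
Not part of the patch: a minimal standalone sketch that sanity-checks the shuffle-based pre-CUDA-12 fallback against the movmatrix.sync.aligned.m8n8.trans.b16 semantics. It assumes the usual m8n8 b16 fragment layout (lane l holds row l/4, columns 2*(l%4) and 2*(l%4)+1, packed low/high into one 32-bit register); the names movmatrix_emulated and transpose_check are hypothetical and exist only for this test.

// Hypothetical verification harness, not from the patch. Compile with nvcc.
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Same logic as the CUDART_VERSION < 12000 fallback in the patch above.
static __device__ __forceinline__ int movmatrix_emulated(const int x) {
    const int src_i_low  = 2 * (threadIdx.x % 4);
    const int src_i_high = src_i_low + 1;
    const int src_j      = threadIdx.x / 4;

    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
    const int src_laneid_high = src_i_high * 4 + src_j / 2;

    const int shift_low  = ((src_j + 0) % 2) * 16;
    const int shift_high = ((src_j + 1) % 2) * 16;

    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;

    return ret_low | ret_high;
}

static __global__ void transpose_check(int * out) {
    // Lane l holds A[i][j] and A[i][j+1] with i = l/4, j = 2*(l%4).
    // Encode each 16-bit element as 8*row + col so the host can identify it.
    const int i = threadIdx.x / 4;
    const int j = 2 * (threadIdx.x % 4);
    const int x = (uint16_t)(8*i + j) | ((uint16_t)(8*i + j + 1) << 16);
    out[threadIdx.x] = movmatrix_emulated(x);
}

int main() {
    int * d_out;
    cudaMalloc(&d_out, WARP_SIZE*sizeof(int));
    transpose_check<<<1, WARP_SIZE>>>(d_out);
    int h_out[WARP_SIZE];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    bool ok = true;
    for (int l = 0; l < WARP_SIZE; ++l) {
        const int i = l / 4;       // row of the transposed fragment held by lane l
        const int j = 2 * (l % 4); // first of its two columns
        const uint16_t lo = h_out[l] & 0xFFFF;
        const uint16_t hi = h_out[l] >> 16;
        // Transposed element T[i][j] == A[j][i], i.e. the encoding 8*j + i.
        ok = ok && lo == 8*j + i && hi == 8*(j + 1) + i;
    }
    printf("%s\n", ok ? "transpose OK" : "MISMATCH");
    cudaFree(d_out);
    return 0;
}

If the index math in the fallback is right, every lane l ends up with T[i][j] = A[j][i] in its low halfword and T[i][j+1] = A[j+1][i] in its high halfword, and the program prints "transpose OK".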