__shfl_sync workaround for movmatrix

Johannes Gäßler 2025-02-02 12:14:36 +01:00
parent 60958f60ea
commit e3b7c574b1

@@ -15,6 +15,44 @@
 
 #include "common.cuh"
 
+#if CUDART_VERSION >= 12000
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    int ret = 0;
+
+#ifdef NEW_MMA_AVAILABLE
+    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+        : "+r"(ret) : "r"(x));
+#else
+    NO_DEVICE_CODE;
+#endif // defined(NEW_MMA_AVAILABLE)
+
+    return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing row-major matrix to column-major matrix.
+
+    const int src_i_low  = 2 * (threadIdx.x % 4);
+    const int src_i_high = src_i_low + 1;
+    const int src_j      = threadIdx.x / 4;
+
+    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
+    const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+    const int shift_low  = ((src_j + 0) % 2) * 16;
+    const int shift_high = ((src_j + 1) % 2) * 16;
+
+    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+    return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 12000
+
 template <typename T>
 struct mma_A_I16K4 {
     static_assert(sizeof(T) == 4, "bad type size");
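
For reference, a minimal host-side sketch (not part of the commit) of what the __shfl_sync fallback computes. It assumes the standard m8n8 b16 fragment layout, where lane l holds elements (l/4, 2*(l%4)) and (l/4, 2*(l%4)+1) of an 8x8 matrix in the low/high halfwords of one 32-bit register, and it emulates the warp shuffle with a plain array lookup so the lane/halfword index math can be checked against an ordinary transpose:

// Host-side sketch, not part of the commit: checks the lane/halfword index
// math of the fallback against a plain 8x8 transpose. frag[src_lane] stands
// in for __shfl_sync, which returns the register value held by src_lane.
#include <cassert>
#include <cstdint>

int main() {
    uint16_t A[8][8];
    for (int i = 0; i < 8; ++i)
        for (int j = 0; j < 8; ++j)
            A[i][j] = uint16_t(i*8 + j);

    // Assumed m8n8 b16 fragment layout: lane l holds A[l/4][2*(l%4)] in the
    // low halfword and A[l/4][2*(l%4) + 1] in the high halfword.
    uint32_t frag[32];
    for (int l = 0; l < 32; ++l) {
        const int i = l / 4, j = 2 * (l % 4);
        frag[l] = uint32_t(A[i][j]) | (uint32_t(A[i][j + 1]) << 16);
    }

    // Same index math as the ggml_cuda_movmatrix fallback above.
    uint32_t out[32];
    for (int l = 0; l < 32; ++l) {
        const int src_i_low  = 2 * (l % 4);
        const int src_i_high = src_i_low + 1;
        const int src_j      = l / 4;

        const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
        const int src_laneid_high = src_i_high * 4 + src_j / 2;

        const int shift_low  = ((src_j + 0) % 2) * 16;
        const int shift_high = ((src_j + 1) % 2) * 16;

        const uint32_t ret_low  = (frag[src_laneid_low]  >> shift_low)  & 0x0000FFFFu;
        const uint32_t ret_high = (frag[src_laneid_high] << shift_high) & 0xFFFF0000u;
        out[l] = ret_low | ret_high;
    }

    // out must hold the transposed matrix in the same fragment layout.
    for (int l = 0; l < 32; ++l) {
        const int i = l / 4, j = 2 * (l % 4);
        assert(uint16_t(out[l] & 0xFFFF) == A[j][i]);
        assert(uint16_t(out[l] >> 16)    == A[j + 1][i]);
    }
    return 0;
}
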
@@ -119,21 +157,14 @@ struct mma_A_I16K8 {
     }
 
     __device__ __forceinline__ void transpose() {
-#ifdef NEW_MMA_AVAILABLE
         int * xi = (int *) x;
 
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(xi[0]) : "r"(xi[0]));
+        xi[0] = ggml_cuda_movmatrix(xi[0]);
 
-        int tmp = 0;
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(tmp) : "r"(xi[1]));
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(xi[1]) : "r"(xi[2]));
+        const int tmp = ggml_cuda_movmatrix(xi[1]);
+        xi[1] = ggml_cuda_movmatrix(xi[2]);
         xi[2] = tmp;
 
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(xi[3]) : "r"(xi[3]));
-#else
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+        xi[3] = ggml_cuda_movmatrix(xi[3]);
     }
 };
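
The rewritten transpose() treats the fragment as a 2x2 grid of 8x8 b16 tiles: ggml_cuda_movmatrix transposes each tile within the warp, and the tmp swap exchanges the two off-diagonal tiles. Below is a minimal host-side sketch (not part of the commit) of the block-transpose identity this relies on; the mapping of xi[0..3] to a row-major 2x2 block order is assumed here purely for illustration:

#include <cassert>
#include <cstdint>

int main() {
    // 16x16 b16 matrix, viewed as a 2x2 grid of 8x8 tiles B[bi][bj]
    // (standing in for xi[0..3]; the block order is an assumption).
    uint16_t A[16][16], B[2][2][8][8];
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 16; ++j)
            A[i][j] = uint16_t(i*16 + j);

    // Step 1: transpose every 8x8 tile in place (the per-tile movmatrix).
    for (int bi = 0; bi < 2; ++bi)
        for (int bj = 0; bj < 2; ++bj)
            for (int i = 0; i < 8; ++i)
                for (int j = 0; j < 8; ++j)
                    B[bi][bj][i][j] = A[8*bi + j][8*bj + i];

    // Step 2: swap the two off-diagonal tiles (the tmp swap of xi[1]/xi[2]).
    for (int i = 0; i < 8; ++i)
        for (int j = 0; j < 8; ++j) {
            const uint16_t tmp = B[0][1][i][j];
            B[0][1][i][j] = B[1][0][i][j];
            B[1][0][i][j] = tmp;
        }

    // The result is the full 16x16 transpose:
    // [X0 X1; X2 X3]^T == [X0^T X2^T; X1^T X3^T].
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 16; ++j)
            assert(B[i/8][j/8][i%8][j%8] == A[j][i]);
    return 0;
}
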
@@ -350,16 +381,10 @@ struct mma_C_I16J8<half2> {
     __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
         mma_B_J8K8<half2> mma_B;
-#ifdef NEW_MMA_AVAILABLE
+
         int * xi  = (int *) x;
         int * Bxi = (int *) mma_B.x;
-
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(Bxi[0]) : "r"(xi[0]));
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-            : "+r"(Bxi[1]) : "r"(xi[1]));
-#else
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+        Bxi[0] = ggml_cuda_movmatrix(xi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(xi[1]);
 
         return mma_B;
     }
@@ -417,15 +442,9 @@ struct mma_C_I16J8<float> {
         mma_B.x[0] = make_half2(x[0], x[1]);
         mma_B.x[1] = make_half2(x[2], x[3]);
-#ifdef NEW_MMA_AVAILABLE
+
         int * Bxi = (int *) mma_B.x;
-
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
-            : "+r"(Bxi[0]) : );
-        asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
-            : "+r"(Bxi[1]) : );
-#else
-        NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+        Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
 
         return mma_B;
     }