__shfl_sync workaround for movmatrix
This commit is contained in:
parent
60958f60ea
commit
e3b7c574b1
1 changed files with 48 additions and 29 deletions
|
@ -15,6 +15,44 @@
|
|||
|
||||
#include "common.cuh"
|
||||
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
int ret = 0;
|
||||
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(ret) : "r"(x));
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(NEW_MMA_AVAILABLE)
|
||||
return ret;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
// Imagine transposing row-major matrix to column-major matrix.
|
||||
const int src_i_low = 2 * (threadIdx.x % 4);
|
||||
const int src_i_high = src_i_low + 1;
|
||||
const int src_j = threadIdx.x / 4;
|
||||
|
||||
const int src_laneid_low = src_i_low * 4 + src_j / 2;
|
||||
const int src_laneid_high = src_i_high * 4 + src_j / 2;
|
||||
|
||||
const int shift_low = ((src_j + 0) % 2) * 16;
|
||||
const int shift_high = ((src_j + 1) % 2) * 16;
|
||||
|
||||
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
|
||||
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
|
||||
|
||||
return ret_low | ret_high;
|
||||
}
|
||||
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct mma_A_I16K4 {
|
||||
static_assert(sizeof(T) == 4, "bad type size");
|
||||
|
@ -119,21 +157,14 @@ struct mma_A_I16K8 {
|
|||
}
|
||||
|
||||
__device__ __forceinline__ void transpose() {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(xi[0]) : "r"(xi[0]));
|
||||
int tmp = 0;
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(tmp) : "r"(xi[1]));
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(xi[1]) : "r"(xi[2]));
|
||||
xi[0] = ggml_cuda_movmatrix(xi[0]);
|
||||
|
||||
const int tmp = ggml_cuda_movmatrix(xi[1]);
|
||||
xi[1] = ggml_cuda_movmatrix(xi[2]);
|
||||
xi[2] = tmp;
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(xi[3]) : "r"(xi[3]));
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
|
||||
xi[3] = ggml_cuda_movmatrix(xi[3]);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -350,16 +381,10 @@ struct mma_C_I16J8<half2> {
|
|||
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
|
||||
mma_B_J8K8<half2> mma_B;
|
||||
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * xi = (int *) x;
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(Bxi[0]) : "r"(xi[0]));
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "+r"(Bxi[1]) : "r"(xi[1]));
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
|
||||
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
|
||||
|
||||
return mma_B;
|
||||
}
|
||||
|
@ -417,15 +442,9 @@ struct mma_C_I16J8<float> {
|
|||
mma_B.x[0] = make_half2(x[0], x[1]);
|
||||
mma_B.x[1] = make_half2(x[2], x[3]);
|
||||
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
int * Bxi = (int *) mma_B.x;
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
|
||||
: "+r"(Bxi[0]) : );
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;"
|
||||
: "+r"(Bxi[1]) : );
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
|
||||
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
|
||||
|
||||
return mma_B;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue