CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (#7860)
This commit is contained in:
parent
c2ce6c47e4
commit
bdcb8f4222
2 changed files with 360 additions and 6 deletions
|
@ -1,5 +1,27 @@
|
|||
#include "common.cuh"
|
||||
|
||||
struct mma_int_A_I16K4 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int /* l */) {
|
||||
const int ret = threadIdx.x % K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_A_I16K8 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 8;
|
||||
|
@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
|
|||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K4 {
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 4;
|
||||
static constexpr int ne = 1;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int /* l */) {
|
||||
const int ret = threadIdx.x % K;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K8 {
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 8;
|
||||
|
@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
|
|||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
#else
|
||||
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue