dequantize + matrix multiplication CUDA kernels
parent d3494bb86b
commit 2d55023143

3 changed files with 519 additions and 29 deletions
@@ -71,6 +71,7 @@ option(LLAMA_CUBLAS "llama: use cuBLAS"
 set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
 option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
+option(LLAMA_CUDA_DMM "llama: use dequantize mul mat CUDA kernels" OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_METAL "llama: use Metal" OFF)

@@ -251,6 +252,9 @@ if (LLAMA_CUBLAS)
     if (LLAMA_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_DMMV_F16)
     endif()
+    if (LLAMA_CUDA_DMM)
+        add_compile_definitions(GGML_CUDA_DMM)
+    endif()
     add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})

     if (LLAMA_STATIC)
Makefile (3 changed lines)
@@ -179,6 +179,9 @@ endif # LLAMA_CUDA_DMMV_Y
 ifdef LLAMA_CUDA_DMMV_F16
     NVCCFLAGS += -DGGML_CUDA_DMMV_F16
 endif # LLAMA_CUDA_DMMV_F16
+ifdef LLAMA_CUDA_DMM
+    NVCCFLAGS += -DGGML_CUDA_DMM
+endif # LLAMA_CUDA_DMM
 ifdef LLAMA_CUDA_KQUANTS_ITER
     NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
 else
ggml-cuda.cu (541 changed lines)
@@ -58,7 +58,8 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_DMMV_F16
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_2_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef float (*dequantize_1_kernel_t)(const void * vx, const int i);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@@ -280,7 +281,257 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 }
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ float dequantize_1_f32(const void * vx, const int i){
+    const float * x = (const float *) vx;
+
+    return x[i];
+}
+
+static __device__ __forceinline__ float dequantize_1_f16(const void * vx, const int i){
+    const half * x = (const half *) vx;
+
+    return __half2float(x[i]);
+}
+
+static __device__ __forceinline__ float dequantize_1_q4_0(const void * vx, const int i){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+    const int ib = i / QK4_0;
+
+    const float d = x[ib].d;
+
+    const int iqs0 = i % QK4_0;
+    const int shift = iqs0 / (QK4_0/QR4_0);
+    const int iqs = iqs0 - shift * (QK4_0/QR4_0);
+
+    int vi = x[ib].qs[iqs];
+
+    vi >>= 4 * shift;
+    vi &= 0xF;
+
+    return (vi - 8) * d;
+}
+
+static __device__ __forceinline__ float dequantize_1_q4_1(const void * vx, const int i){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+    const int ib = i / QK4_1;
+
+    const float d = x[ib].d;
+    const float m = x[ib].m;
+
+    const int iqs0 = i % QK4_1;
+    const int shift = iqs0 / (QK4_1/QR4_1);
+    const int iqs = iqs0 - shift * (QK4_1/QR4_1);
+
+    int vi = x[ib].qs[iqs];
+
+    vi >>= 4 * shift;
+    vi &= 0xF;
+
+    return vi * d + m;
+}
+
+static __device__ __forceinline__ float dequantize_1_q5_0(const void * vx, const int i){
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+    const int ib = i / QK4_0;
+
+    const float d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int iqs0 = i % QK5_0;
+    const int shift = iqs0 / (QK5_0/QR5_0);
+    const int not_shift = shift ^ 1;
+    const int iqs = iqs0 - shift * (QK5_0/QR5_0);
+
+    int vi = x[ib].qs[iqs];
+    vi >>= 4 * shift;
+    vi &= 0xF;
+
+    const int xh = ((qh >> (iqs + 12*shift)) << not_shift*4) & 0x10;
+    vi |= xh;
+
+    return (vi - 16) * d;
+}
+
+static __device__ __forceinline__ float dequantize_1_q5_1(const void * vx, const int i){
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+    const int ib = i / QK4_0;
+
+    const float d = x[ib].d;
+    const float m = x[ib].m;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int iqs0 = i % QK5_0;
+    const int shift = iqs0 / (QK5_0/QR5_0);
+    const int not_shift = shift ^ 1;
+    const int iqs = iqs0 - shift * (QK5_0/QR5_0);
+
+    int vi = x[ib].qs[iqs];
+    vi >>= 4 * shift;
+    vi &= 0xF;
+
+    const int xh = ((qh >> (iqs + 12*shift)) << not_shift*4) & 0x10;
+    vi |= xh;
+
+    return vi * d + m;
+}
+
+static __device__ __forceinline__ float dequantize_1_q8_0(const void * vx, const int i){
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+    const int ib = i / QK8_0;
+
+    const float d = x[ib].d;
+
+    const int iqs = i % QK8_0;
+
+    const float v = x[ib].qs[iqs];
+
+    return v * d;
+}
+
+static __device__ __forceinline__ float dequantize_1_q2_K(const void * vx, const int i){
+    const block_q2_K * x = (const block_q2_K *) vx;
+    const int ib = i / QK_K;
+
+    const int iy = i % QK_K;
+    const int n = iy / (QK_K/2);
+
+    const float d = x[ib].d;
+    const float dmin = x[ib].dmin;
+
+    const int iqs = iy % (QK_K/8) + n * (QK_K/8);
+    const int qs_shift = 2 * ((iy % (QK_K/2)) / (QK_K/8));
+    const int qs = (x[ib].qs[iqs] >> qs_shift) & 3;
+
+    const int isc = iy / (QK_K/16);
+    const int sc = x[ib].scales[isc];
+
+    const float dl = d * (sc & 0xF);
+    const float ml = dmin * (sc >> 4);
+
+    return dl * qs - ml;
+}
+
+static __device__ __forceinline__ float dequantize_1_q3_K(const void * vx, const int i){
+    const block_q3_K * x = (const block_q3_K *) vx;
+    const int ib = i / QK_K;
+
+    const int iy = i % QK_K;
+    const int n = iy / (QK_K/2);
+
+    const float d = x[ib].d;
+
+    const int iqs = iy % (QK_K/8) + n * (QK_K/8);
+    const int qs_shift = 2 * ((iy % (QK_K/2)) / (QK_K/8));
+    const int qs = (x[ib].qs[iqs] >> qs_shift) & 3;
+
+    const int ih = iy % (QK_K/8);
+    const int ih_shift = iy / (QK_K/8);
+    const int h = x[ib].hmask[ih] & (1 << ih_shift) ? 0 : 4;
+
+    const int q = qs - h;
+
+    const int isc = iy / (QK_K/16);
+
+    const int isc_low = isc % (QK_K/32);
+    const int sc_shift_low = 4 * (isc / (QK_K/32));
+    const int sc_low = (x[ib].scales[isc_low] >> sc_shift_low) & 0xF;
+
+    const int isc_high = isc % (QK_K/64);
+    const int sc_shift_high = 2 * (isc / (QK_K/64));
+    const int sc_high = ((x[ib].scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+    const int sc = (sc_low | sc_high) - 32;
+
+    return d * sc * q;
+}
+
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+static __device__ __forceinline__ float dequantize_1_q4_K(const void * vx, const int i){
+    const block_q4_K * x = (const block_q4_K *) vx;
+    const int ib = i / QK_K;
+
+    const int iy = i % QK_K;
+    const int j = iy / (QK_K/4);
+
+    const float d = x[ib].d;
+    const float dmin = x[ib].dmin;
+
+    const int iqs = iy % (QK_K/8) + j * (QK_K/8);
+    const int qs_shift = 4 * ((iy % (QK_K/4)) / (QK_K/8));
+    const int qs = (x[ib].qs[iqs] >> qs_shift) & 0xF;
+
+    const int isc = iy / (QK_K/8);
+    uint8_t sc, m;
+    get_scale_min_k4(isc, x[ib].scales, sc, m);
+
+    return d * sc * qs - dmin * m;
+}
+
+static __device__ __forceinline__ float dequantize_1_q5_K(const void * vx, const int i){
+    const block_q5_K * x = (const block_q5_K *) vx;
+    const int ib = i / QK_K;
+
+    const int iy = i % QK_K;
+    const int j = iy / (QK_K/4);
+
+    const float d = x[ib].d;
+    const float dmin = x[ib].dmin;
+
+    const int iqs = iy % (QK_K/8) + j * (QK_K/8);
+    const int qs_shift = 4 * ((iy % (QK_K/4)) / (QK_K/8));
+    const int qs = (x[ib].qs[iqs] >> qs_shift) & 0xF;
+
+    const int isc = iy / (QK_K/8);
+    uint8_t sc, m;
+    get_scale_min_k4(isc, x[ib].scales, sc, m);
+
+    const int iqh = iy % (QK_K/8);
+    const int qh_shift = iy / (QK_K/8);
+    const int qh = 16 * ((x[ib].qh[iqh] >> qh_shift) & 1);
+
+    const int q = qs + qh;
+
+    return d * sc * q - dmin * m;
+}
+
+static __device__ __forceinline__ float dequantize_1_q6_K(const void * vx, const int i){
+    const block_q6_K * x = (const block_q6_K *) vx;
+    const int ib = i / QK_K;
+
+    const int iy = i % QK_K;
+    const int n = iy / (QK_K/2);
+
+    const float d = x[ib].d;
+
+    const int iql = iy % (QK_K/4) + n * (QK_K/4);
+    const int ql_shift = 4 * ((iy % (QK_K/2)) / (QK_K/4));
+    const int ql = (x[ib].ql[iql] >> ql_shift) & 0xF;
+
+    const int iqh = iy % (QK_K/8) + n * (QK_K/8);
+    const int qh_shift = 2 * ((iy % (QK_K/2)) / (QK_K/8));
+    const int qh = (((x[ib].qh[iqh] >> qh_shift) & 3) << 4);
+
+    const int q = (ql | qh) - 32;
+
+    const int isc = iy / (QK_K/16);
+    const int sc = x[ib].scales[isc];
+
+    return d * sc * q;
+}
+
+static __device__ __forceinline__ void dequantize_2_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const dfloat d = x[ib].d;
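A note on the indexing in the new dequantize_1_* functions above: each call maps a flat element index i to a block index, a byte offset and a nibble (or bit-group) selector, and returns exactly one dequantized weight. The minimal host-side sketch below mirrors the Q4_0 case for illustration only; it assumes the usual QK4_0 = 32 and QR4_0 = 2 and uses a simplified stand-in struct (block_q4_0_ref, with a plain float scale) rather than the real block_q4_0.

#include <cstdint>

// Illustrative stand-in for block_q4_0: 32 weights packed as 16 bytes of
// 4-bit values plus one scale (the real struct stores the scale as a half).
#define QK4_0 32
#define QR4_0 2

struct block_q4_0_ref {
    float   d;              // scale
    uint8_t qs[QK4_0 / 2];  // element j and element j+16 share byte j
};

// Host-side mirror of dequantize_1_q4_0: one dequantized weight per call.
static float dequantize_1_q4_0_ref(const block_q4_0_ref * x, int i) {
    const int ib    = i / QK4_0;                    // which block
    const int iqs0  = i % QK4_0;                    // index within the block
    const int shift = iqs0 / (QK4_0/QR4_0);         // 0 -> low nibble, 1 -> high nibble
    const int iqs   = iqs0 - shift * (QK4_0/QR4_0); // byte holding the nibble

    int vi = x[ib].qs[iqs];
    vi >>= 4 * shift;
    vi &= 0xF;

    return (vi - 8) * x[ib].d;                      // recenter the 4-bit value around zero
}

Elements 0-15 of a block therefore read the low nibbles of qs[0..15] and elements 16-31 read the high nibbles of the same bytes, which is also why the dequantize_2_* variants can serve two output values from a single byte load.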
@@ -299,7 +550,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 #endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_2_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
     const dfloat d = x[ib].d;
@@ -319,7 +570,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in
 #endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_2_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -342,7 +593,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
 #endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_2_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
     const dfloat d = x[ib].d;
@@ -366,7 +617,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
 #endif // GGML_CUDA_DMMV_F16
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_2_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -470,17 +721,6 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
 }
 
-#if QK_K == 256
-static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
-    if (j < 4) {
-        d = q[j] & 63; m = q[j + 4] & 63;
-    } else {
-        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
-    }
-}
-#endif
-
 static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
@@ -1153,7 +1393,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+template <int qk, int qr, dequantize_2_kernel_t dequantize_kernel>
 static __global__ void dequantize_block(const void * vx, float * y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
@@ -1174,7 +1414,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     y[iybs + iqs + y_offset] = v.y;
 }
 
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+template <int qk, int qr, dequantize_2_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
@@ -1243,6 +1483,73 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
     }
 }
 
+template <dequantize_1_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst) {
+
+    const int nrows_y = ncols_x;
+    const int ncols_dst = ncols_y;
+
+    const int tid = threadIdx.x;
+
+    const int row_dst_0 = blockIdx.x*WARP_SIZE;
+    const int row_x_0 = row_dst_0;
+    const int row_dst = row_dst_0 + tid;
+
+    const int col_dst_0 = blockIdx.y*WARP_SIZE;
+    const int col_y_0 = col_dst_0;
+
+    __shared__ float tile_x[WARP_SIZE][WARP_SIZE + 1];
+    __shared__ float tile_y[WARP_SIZE][WARP_SIZE];
+    float sum[WARP_SIZE] = {0.0f};
+
+    for (int col_x_0 = 0; col_x_0 < ncols_x; col_x_0 += WARP_SIZE) {
+        const int row_y_0 = col_x_0;
+
+        const int col_x_tile = min(col_x_0 + tid, ncols_x-1);
+
+#pragma unroll
+        for (int j = 0; j < WARP_SIZE; ++j) {
+            const int row_x_tile = min(row_x_0 + j, nrows_x-1);
+            tile_x[j][tid] = dequantize_kernel(vx, row_x_tile*ncols_x + col_x_tile);
+        }
+
+        const int row_y_tile = min(row_y_0 + tid, nrows_y-1);
+
+#pragma unroll
+        for (int i = 0; i < WARP_SIZE; ++i) {
+            const int col_y_tile = min(col_y_0 + i, ncols_y-1);
+            tile_y[i][tid] = y[col_y_tile*nrows_y + row_y_tile];
+        }
+
+#pragma unroll
+        for (int i = 0; i < WARP_SIZE; ++i) {
+            const float xi = tile_x[tid][i];
+
+#pragma unroll
+            for (int j = 0; j < WARP_SIZE; ++j) {
+                const float yi = tile_y[j][i];
+                sum[j] += xi*yi;
+            }
+        }
+    }
+
+    if (row_dst >= nrows_dst) {
+        return;
+    }
+
+#pragma unroll
+    for (int j = 0; j < WARP_SIZE; ++j) {
+        const int col_dst_j = col_dst_0 + j;
+
+        if (col_dst_j >= ncols_dst) {
+            break;
+        }
+
+        dst[col_dst_j*nrows_dst + row_dst] = sum[j];
+    }
+}
+
 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (const half *) vx;
 
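Layout assumptions of the dequantize_mul_mat kernel above, which are easy to misread from the tiling: x is addressed row-major (row*ncols_x + col) and is dequantized on the fly into a shared-memory tile padded by one column to avoid bank conflicts, while y and dst are addressed column-major (col*nrows + row); each thread block produces one WARP_SIZE x WARP_SIZE tile of dst. The host-side reference below is an illustrative sketch only (mul_mat_ref is a hypothetical helper, shown for the dequantize_1_f32 case where x is already float); it restates the indexing without the tiling.

// x:   nrows_x x ncols_x, row-major
// y:   ncols_x x ncols_y, column-major (nrows_y == ncols_x)
// dst: nrows_dst x ncols_dst, column-major (ncols_dst == ncols_y)
static void mul_mat_ref(const float * x, const float * y, float * dst,
                        int ncols_x, int nrows_x, int ncols_y, int nrows_dst) {
    const int nrows_y = ncols_x;

    for (int row = 0; row < nrows_x && row < nrows_dst; ++row) {
        for (int col = 0; col < ncols_y; ++col) {
            float sum = 0.0f;
            for (int k = 0; k < ncols_x; ++k) {
                sum += x[row*ncols_x + k] * y[col*nrows_y + k];
            }
            dst[col*nrows_dst + row] = sum;
        }
    }
}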
@@ -1491,27 +1798,27 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
 
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<QK4_0, QR4_0, dequantize_2_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<QK4_1, QR4_1, dequantize_2_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<QK5_0, QR5_0, dequantize_2_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<QK5_1, QR5_1, dequantize_2_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<QK8_0, QR8_0, dequantize_2_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1560,7 +1867,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_2_q4_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -1569,7 +1876,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_2_q4_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -1578,7 +1885,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_2_q5_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -1587,7 +1894,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_2_q5_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -1596,7 +1903,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_2_q8_0>
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
@@ -1685,6 +1992,102 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
+static void ggml_dequantize_mul_mat_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_f32><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_f16_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_f16><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q4_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q4_1><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q5_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q5_1><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q8_0><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_dequantize_mul_mat_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    dequantize_mul_mat<dequantize_1_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
 static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
     const dim3 block_nums(1, nrows_x, nchannels_x);
     const dim3 block_dims(WARP_SIZE, 1, 1);
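All of the ggml_dequantize_mul_mat_*_cuda launchers above use the same geometry: a two-dimensional grid of ceil(nrows_x / WARP_SIZE) x ceil(ncols_y / WARP_SIZE) blocks, each with WARP_SIZE threads (WARP_SIZE is 32 in this file), so every block covers one 32 x 32 tile of dst. A worked example with made-up sizes:

// Illustrative numbers only: a 4096 x 4096 weight matrix times 512 input columns.
const int WARP_SIZE = 32;
const int nrows_x = 4096, ncols_y = 512;

const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE; // 128 blocks along the rows of dst
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; //  16 blocks along the columns of dst
// => grid (128, 16, 1), block (32, 1, 1): 2048 thread blocks of 32 threads,
//    each computing one 32 x 32 tile of dst.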
@@ -1848,9 +2251,11 @@ void ggml_init_cublas() {
         // create main stream
         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
 
+#ifndef GGML_CUDA_DMM
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
         CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+#endif // GGML_CUDA_DMM
     }
 
     // configure logging to stdout
@@ -2140,6 +2545,80 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
     (void) i1;
 }
 
+inline void ggml_cuda_op_dequantize_mul_mat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddq_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne11 = src1->ne[1];
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t i01_diff = i01_high - i01_low;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_dequantize_mul_mat_f32_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_F16:
+            ggml_dequantize_mul_mat_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_0:
+            ggml_dequantize_mul_mat_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_1:
+            ggml_dequantize_mul_mat_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_0:
+            ggml_dequantize_mul_mat_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_1:
+            ggml_dequantize_mul_mat_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q8_0:
+            ggml_dequantize_mul_mat_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q2_K:
+            ggml_dequantize_mul_mat_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q3_K:
+            ggml_dequantize_mul_mat_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_K:
+            ggml_dequantize_mul_mat_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_K:
+            ggml_dequantize_mul_mat_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        case GGML_TYPE_Q6_K:
+            ggml_dequantize_mul_mat_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_mul_mat_cublas(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2682,7 +3161,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
         } else {
+#ifdef GGML_CUDA_DMM
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat, false, false);
+#else
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+#endif // GGML_CUDA_DMM
         }
     } else {
         GGML_ASSERT(false);