Vectorize load instructions in dmmv f16 CUDA kernel

Replaces scalar load instructions with vector load instructions, which
substantially improves performance on NVIDIA GPUs with HBM memory: e.g. a
1.27X overall speedup for Meta-Llama-3-8B-Instruct-F16 batch-size-1 (BS1)
inference evaluation on an H100 SXM 80GB HBM3. On GPUs with GDDR memory
the speedup is slight (1.01X).
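
The optimization coalesces pairs of adjacent 16-bit loads into single
32-bit loads by reinterpreting two consecutive halfs as a half2. A minimal
standalone sketch of the pattern (illustrative only, not code from this
commit; the function names are invented):

    #include <cuda_fp16.h>

    __device__ float dot2_scalar(const half * x, const half * y, int i) {
        // two 16-bit loads per operand
        return __half2float(x[i])     * __half2float(y[i])
             + __half2float(x[i + 1]) * __half2float(y[i + 1]);
    }

    __device__ float dot2_vectorized(const half * x, const half * y, int i) {
        // one 32-bit load per operand; assumes &x[i] and &y[i] are
        // 4-byte aligned, i.e. i is even for naturally aligned buffers
        const half2 xr = *((const half2 *) &x[i]);
        const half2 yr = *((const half2 *) &y[i]);
        const half2 p  = __hmul2(xr, yr);
        return __low2float(p) + __high2float(p);
    }

Both functions compute the same result; the vectorized version halves the
number of load instructions issued, which matters most when the kernel is
bound by memory-instruction throughput, as on the HBM parts above.
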
commit 95c8b9c1b7 (parent e7022064ab)
Author: Alan Gray
Date:   2024-10-08 05:00:26 -07:00

@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
 static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
-
+    // load 2 halfs into register in a single instruction
+    const half2 x_reg = *((half2 *) &(x[ib + iqs]));
     // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
+    v.x = x_reg.x;
+    v.y = x_reg.y;
 }
 
 static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
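
For reference, dfloat2 in ggml-cuda is a typedef that follows the
GGML_CUDA_F16 build option (paraphrased from common.cuh in the same tree),
which is why the same reinterpret-cast load pattern works in both
preprocessor branches of the next hunk:

    #ifdef GGML_CUDA_F16
    typedef half  dfloat;   // dequantize float
    typedef half2 dfloat2;  // 2 x 16 bits: one 32-bit vector load
    #else
    typedef float  dfloat;  // dequantize float
    typedef float2 dfloat2; // 2 x 32 bits: one 64-bit vector load
    #endif // GGML_CUDA_F16
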
@@ -476,13 +477,31 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
             // matrix multiplication
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 #ifdef GGML_CUDA_F16
-            tmp += __hmul2(v, {
-                y[iybs + iqs + j/qr + 0],
-                y[iybs + iqs + j/qr + y_offset]
-            });
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += __hmul2(v, {
+                    y_reg.x,
+                    y_reg.y
+                });
+            }
+            else {
+                tmp += __hmul2(v, {
+                    y[iybs + iqs + j/qr + 0],
+                    y[iybs + iqs + j/qr + y_offset]
+                });
+            }
 #else
-            tmp += v.x * y[iybs + iqs + j/qr + 0];
-            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += v.x * y_reg.x;
+                tmp += v.y * y_reg.y;
+            }
+            else {
+                tmp += v.x * y[iybs + iqs + j/qr + 0];
+                tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            }
 #endif // GGML_CUDA_F16
         }
     }
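
The y_offset == 1 guard in the hunk above is what keeps the change safe: a
dfloat2 load can only fetch two adjacent elements, and y_offset == 1 holds
exactly when qr == 1 (the f16 path), so quantized types with strided y
accesses keep the scalar fallback. A sketch of the distinction, with the
alignment requirement made explicit as a check (illustrative only; the
kernel itself relies on its indexing to keep the address aligned):

    #include <cuda_fp16.h>
    #include <stdint.h>

    __device__ float2 load_pair(const half * y, int idx, int y_offset) {
        if (y_offset == 1 &&
            ((uintptr_t) &y[idx]) % alignof(half2) == 0) {
            // fast path: y[idx] and y[idx + 1] fetched in one 32-bit load
            const half2 r = *((const half2 *) &y[idx]);
            return make_float2(__low2float(r), __high2float(r));
        }
        // fallback: two scalar loads of possibly non-adjacent elements
        return make_float2(__half2float(y[idx]),
                           __half2float(y[idx + y_offset]));
    }

Misaligned vector loads are not valid in CUDA, so any port of this pattern
needs either an alignment guarantee from the indexing, as here, or an
explicit check like the one in the sketch.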