# Add an opt-in half-precision path to the CUDA dequantize + mat-vec code.
#
# Makefile:
#   Building with LLAMA_CUDA_DMMV_F16 set passes -DGGML_CUDA_DMMV_F16 to nvcc.
#
# ggml-cuda.cu:
#   * With GGML_CUDA_DMMV_F16 defined, dfloat/dfloat2 become half/half2 and the
#     in-place helpers are built on __hadd2/__hsub2/__hmul2; otherwise they stay
#     float/float2 with component-wise arithmetic.
#   * dmul_inplace is renamed to dmul_scalar_inplace (the q4_0, q4_1, q5_0,
#     q5_1 and q8_0 dequantize call sites are updated), and a new
#     dmul_vector_inplace multiplies a dfloat2 by another dfloat2.
#   * dequantize_mul_mat_vec keeps its per-thread partial sum in a dfloat and
#     loads the two y values into a dfloat2 without the explicit __half2float.
#
# NOTE(review): with the F16 option the partial sum is accumulated in half
#   precision; confirm the accuracy is acceptable for large ncols.
# NOTE(review): the half2 intrinsics presumably need a GPU with native fp16
#   arithmetic; confirm the minimum compute capability and document it.
# NOTE(review): the line structure and the whitespace of the hunk lines were
#   reconstructed (4-space C indentation, tab-indented Makefile assignments).
#   If context fails to match, apply with a whitespace-tolerant option.

diff --git a/Makefile b/Makefile
index afd06e0a6..5dd676fad 100644
--- a/Makefile
+++ b/Makefile
@@ -169,6 +169,9 @@ ifdef LLAMA_CUDA_DMMV_Y
 else
 	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
 endif # LLAMA_CUDA_DMMV_Y
+ifdef LLAMA_CUDA_DMMV_F16
+	NVCCFLAGS += -DGGML_CUDA_DMMV_F16
+endif # LLAMA_CUDA_DMMV_F16
 ifdef LLAMA_CUDA_KQUANTS_ITER
 	NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
 else
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0a7f1d956..b106fd81f 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -50,6 +50,28 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } while (0)
 #endif // CUDART_VERSION >= 11
 
+#ifdef GGML_CUDA_DMMV_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+
+static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) {
+    a = __hadd2(a, {b, b});
+}
+
+static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) {
+    a = __hsub2(a, {b, b});
+}
+
+static __device__ __forceinline__ void dmul_scalar_inplace(dfloat2 & a, const dfloat b) {
+    a = __hmul2(a, {b, b});
+}
+
+static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const dfloat2 b) {
+    a = __hmul2(a, b);
+}
+
+#else
+
 typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 
@@ -63,11 +85,17 @@ static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b)
     a.y -= b;
 }
 
-static __device__ __forceinline__ void dmul_inplace(dfloat2 & a, const dfloat b) {
+static __device__ __forceinline__ void dmul_scalar_inplace(dfloat2 & a, const dfloat b) {
     a.x *= b;
     a.y *= b;
 }
 
+static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const dfloat2 b) {
+    a.x *= b.x;
+    a.y *= b.y;
+}
+#endif //GGML_CUDA_DMMV_F16
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
@@ -263,7 +291,7 @@ static __device__ void dequantize_q4_0(const void * vx, const int ib, const int
     v.y = vui >> 4;
 
     dsub_inplace(v, 8.0f);
-    dmul_inplace(v, d);
+    dmul_scalar_inplace(v, d);
 }
 
 static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
@@ -277,7 +305,7 @@ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int
     v.x = vui & 0xF;
     v.y = vui >> 4;
 
-    dmul_inplace(v, d);
+    dmul_scalar_inplace(v, d);
     dadd_inplace(v, m);
 }
 
@@ -296,7 +324,7 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
     v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
     dsub_inplace(v, 16.0f);
-    dmul_inplace(v, d);
+    dmul_scalar_inplace(v, d);
 }
 
 static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
@@ -314,7 +342,7 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int
     v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
     v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
 
-    dmul_inplace(v, d);
+    dmul_scalar_inplace(v, d);
     dadd_inplace(v, m);
 }
 
@@ -326,7 +354,7 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
     v.x = x[ib].qs[iqs + 0];
     v.y = x[ib].qs[iqs + 1];
 
-    dmul_inplace(v, d);
+    dmul_scalar_inplace(v, d);
 }
 
 //================================== k-quants
@@ -904,7 +932,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
     const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
 
-    float tmp = 0.0f; // partial sum for thread in warp
+    dfloat tmp = 0.0f; // partial sum for thread in warp
 
     for (int i = 0; i < ncols; i += iter_stride) {
         const int col = i + vals_per_iter*tid;
@@ -923,8 +951,9 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
             // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
 
             // matrix multiplication
-            tmp += v.x * __half2float(y[iybs + iqs + j/qr + 0]);
-            tmp += v.y * __half2float(y[iybs + iqs + j/qr + y_offset]);
+            dfloat2 yi = {y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset]};
+            tmp += v.x * yi.x;
+            tmp += v.y * yi.y;
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
         }
     }