From a2db989667393ae8e95a3a85f778763d97fa0e1e Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sun, 18 Jun 2023 11:36:09 +0200 Subject: [PATCH] cleanup --- ggml-cuda.cu | 99 ++++++++++++++++++++++++++++------------------------ 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 561a4fbe4..9ebc57aff 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -53,29 +53,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); #ifdef GGML_CUDA_DMMV_F16 typedef half dfloat; // dequantize float typedef half2 dfloat2; - -static __device__ __forceinline__ void dadd_scalar_inplace(dfloat2 & a, const dfloat b) { - a = __hadd2(a, {b, b}); -} - -static __device__ __forceinline__ void dadd_vector_inplace(dfloat2 & a, const dfloat2 b) { - a = __hadd2(a, b); -} - -static __device__ __forceinline__ void dsub_scalar_inplace(dfloat2 & a, const dfloat b) { - a = __hsub2(a, {b, b}); -} - -static __device__ __forceinline__ void dmul_scalar_inplace(dfloat2 & a, const dfloat b) { - a = __hmul2(a, {b, b}); -} - -static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const dfloat2 b) { - a = __hmul2(a, b); -} - #else - typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_DMMV_F16 @@ -264,7 +242,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } -static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; const dfloat d = x[ib].d; @@ -274,14 +252,16 @@ static __device__ void dequantize_q4_0(const void * vx, const int ib, const int v.x = vui & 0xF; v.y = vui >> 4; - v.x -= 8.0f; - v.y -= 8.0f; - - v.x *= d; - v.y *= d; +#ifdef GGML_CUDA_DMMV_F16 + v = __hsub2(v, {8.0f, 8.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 8.0f) * d; + v.y = (v.y - 8.0f) * d; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; const dfloat d = x[ib].d; @@ -292,11 +272,16 @@ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int v.x = vui & 0xF; v.y = vui >> 4; +#ifdef GGML_CUDA_DMMV_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else v.x = (v.x * d) + m; v.y = (v.y * d) + m; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; const dfloat d = x[ib].d; @@ -310,14 +295,16 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - v.x -= 16.0f; - v.y -= 16.0f; - - v.x *= d; - v.y *= d; +#ifdef GGML_CUDA_DMMV_F16 + v = __hsub2(v, {16.0f, 16.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 16.0f) * d; + v.y = (v.y - 16.0f) * d; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; const dfloat d = x[ib].d; @@ -332,14 +319,16 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - v.x *= d; - v.y *= d; - - v.x += m; - v.y += m; +#ifdef GGML_CUDA_DMMV_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; const dfloat d = x[ib].d; @@ -347,8 +336,12 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int v.x = x[ib].qs[iqs + 0]; v.y = x[ib].qs[iqs + 1]; +#ifdef GGML_CUDA_DMMV_F16 + v = __hmul2(v, {d, d}); +#else v.x *= d; v.y *= d; +#endif // GGML_CUDA_DMMV_F16 } //================================== k-quants @@ -927,7 +920,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int y_offset = qr == 1 ? 1 : qk/2; - dfloat tmp = 0.0f; // partial sum for thread in warp +// partial sum for each thread +#ifdef GGML_CUDA_DMMV_F16 + half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_CUDA_DMMV_F16 for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; @@ -941,14 +939,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, // process 2 vals per j iter // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val dfloat2 v; dequantize_kernel(vx, ib, iqs + j/qr, v); - // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_CUDA_DMMV_F16 + tmp += __hmul2(v, { + y[iybs + iqs + j/qr + 0], + y[iybs + iqs + j/qr + y_offset] + }); +#else tmp += v.x * y[iybs + iqs + j/qr + 0]; tmp += v.y * y[iybs + iqs + j/qr + y_offset]; - // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#endif // GGML_CUDA_DMMV_F16 } } @@ -960,7 +965,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, } if (tid == 0) { +#ifdef GGML_CUDA_DMMV_F16 + dst[row] = tmp.x + tmp.y; +#else dst[row] = tmp; +#endif // GGML_CUDA_DMMV_F16 } }