From 656c1ab30284c086f49b94035653c44ee3f3176e Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 29 Jul 2023 08:28:14 +0200 Subject: [PATCH] DMMV_F16 -> F16 --- Makefile | 5 ++++- ggml-cuda.cu | 52 ++++++++++++++++++++++++++-------------------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index f74e7a7df..3d1fff8e5 100644 --- a/Makefile +++ b/Makefile @@ -220,8 +220,11 @@ else ifdef LLAMA_CUDA_DMMV_Y else NVCCFLAGS += -DGGML_CUDA_MMV_Y=1 endif # LLAMA_CUDA_MMV_Y +ifdef LLAMA_CUDA_F16 + NVCCFLAGS += -DGGML_CUDA_F16 +endif # LLAMA_CUDA_F16 ifdef LLAMA_CUDA_DMMV_F16 - NVCCFLAGS += -DGGML_CUDA_DMMV_F16 + NVCCFLAGS += -DGGML_CUDA_F16 endif # LLAMA_CUDA_DMMV_F16 ifdef LLAMA_CUDA_KQUANTS_ITER NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 068772a57..3f741dd7d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -52,13 +52,13 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } while (0) #endif // CUDART_VERSION >= 11 -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 typedef half dfloat; // dequantize float typedef half2 dfloat2; #else typedef float dfloat; // dequantize float typedef float2 dfloat2; -#endif //GGML_CUDA_DMMV_F16 +#endif //GGML_CUDA_F16 static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment @@ -400,13 +400,13 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hsub2(v, {8.0f, 8.0f}); v = __hmul2(v, {d, d}); #else v.x = (v.x - 8.0f) * d; v.y = (v.y - 8.0f) * d; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ @@ -420,13 +420,13 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hmul2(v, {d, d}); v = __hadd2(v, {m, m}); #else v.x = (v.x * d) + m; v.y = (v.y * d) + m; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ @@ -443,13 +443,13 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hsub2(v, {16.0f, 16.0f}); v = __hmul2(v, {d, d}); #else v.x = (v.x - 16.0f) * d; v.y = (v.y - 16.0f) * d; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ @@ -467,13 +467,13 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hmul2(v, {d, d}); v = __hadd2(v, {m, m}); #else v.x = (v.x * d) + m; v.y = (v.y * d) + m; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ @@ -484,12 +484,12 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in v.x = x[ib].qs[iqs + 0]; v.y = x[ib].qs[iqs + 1]; -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hmul2(v, {d, d}); #else v.x *= d; v.y *= d; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } //================================== k-quants @@ -1441,14 +1441,14 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( int sumi = __dp4a(vi0, ui0, 0); sumi = __dp4a(vi1, ui1, sumi); -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 const half2 tmp = __hmul2(dm4, ds8); const float d4d8 = __half2float(tmp.x); const float m4s8 = __half2float(tmp.y); #else const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x); const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y); -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 // scale second part of sum by QI8_1/QR4_1 to compensate for multiple threads adding it return sumi * d4d8 + m4s8 / (QI8_1 / QR4_1); @@ -1598,14 +1598,14 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( vi1 |= (qh << 9) & 0x10000000; // 19 -> 28 sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 const half2 tmp = __hmul2(dm5, ds8); const float d5d8 = __half2float(tmp.x); const float m5s8 = __half2float(tmp.y); #else const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x); const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y); -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 return sumi*d5d8 + m5s8/QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block @@ -2603,11 +2603,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons const int y_offset = qr == 1 ? 1 : qk/2; // partial sum for each thread -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics #else float tmp = 0.0f; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; @@ -2627,7 +2627,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons // matrix multiplication // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 tmp += __hmul2(v, { y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset] @@ -2635,7 +2635,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons #else tmp += v.x * y[iybs + iqs + j/qr + 0]; tmp += v.y * y[iybs + iqs + j/qr + y_offset]; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } } @@ -2646,11 +2646,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } if (tid == 0) { -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 dst[row] = tmp.x + tmp.y; #else dst[row] = tmp; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } } @@ -3984,7 +3984,7 @@ inline void ggml_cuda_op_mul_mat_vec( ggml_cuda_pool_free(src1_q8_1, as); } else { // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 size_t ash; dfloat * src1_dfloat = nullptr; // dfloat == half @@ -4000,7 +4000,7 @@ inline void ggml_cuda_op_mul_mat_vec( } #else dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 switch (src0->type) { case GGML_TYPE_Q4_0: @@ -4041,11 +4041,11 @@ inline void ggml_cuda_op_mul_mat_vec( break; } -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 if (src1_convert_f16) { ggml_cuda_pool_free(src1_dfloat, ash); } -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } (void) src1;