From 1d9dd480ff6c2d2c59e9f33f87fc01325d3a4ff7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 14 Jun 2024 17:43:33 +0200
Subject: [PATCH] revert q2_K precision related changes

---
 ggml-cuda/mmq.cuh     | 47 +++++++++++++++++------------------
 ggml-cuda/quantize.cu | 58 +++++++++++--------------------------------
 ggml-cuda/vecdotq.cuh | 13 +++++-----
 3 files changed, 43 insertions(+), 75 deletions(-)

diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh
index 774083249..6d57974fb 100644
--- a/ggml-cuda/mmq.cuh
+++ b/ggml-cuda/mmq.cuh
@@ -882,8 +882,8 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
 
-    const int * y_qs = (const int *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -895,7 +895,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 
             sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
                 &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
-                &x_dm[i*(WARP_SIZE + 1) + k0], y_ds[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
+                &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
         }
     }
 }
@@ -911,7 +911,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     typedef mma_int_C_I16J8 mma_C;
 
     const int   * y_qs = (const int   *) y + 4;
-    const half2 * y_ds = (const half2 *) y;
+    const float * y_df = (const float *) y;
 
     const int i0 = threadIdx.y*mma_A::I;
     static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
@@ -944,10 +944,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
-        mma_C C[2];
+        mma_C Cd[2];
+        mma_C Cm[2];
         mma_B B[2];
         float dB[mma_C::ne/2];
-        float sB[mma_C::ne/2][2];
 
 #pragma unroll
         for (int l = 0; l < mma_B::ne; ++l) {
@@ -961,20 +961,21 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int l = 0; l < mma_C::ne/2; ++l) {
             const int j = j0 + mma_C::get_j(l);
 
-            const half2 tmp = y_ds[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
-            dB[l] = __low2float(tmp);
-
-            const int8_t * sBi = (const int8_t *) &tmp.y;
-            sB[l][0] = (127.0f/8.0f)*sBi[0];
-            sB[l][1] = (127.0f/8.0f)*sBi[1];
+            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
         }
 
-        C[0].mma_K4(A[0], B[0]);
-        C[1].mma_K4(A[1], B[1]);
+        Cd[0].mma_K4(A[0], B[0]);
+        Cd[1].mma_K4(A[1], B[1]);
+
+        mma_A A1;
+        A1.x[0] = 0x01010101;
+        A1.x[1] = 0x01010101;
+        Cm[0].mma_K4(A1, B[0]);
+        Cm[1].mma_K4(A1, B[1]);
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne; ++l) {
-            sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*dA[l/2][0] + C[1].x[l]*dA[l/2][1] + mA[l/2][0]*sB[l%2][0] + mA[l/2][1]*sB[l%2][1])*dB[l%2];
+            sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
         }
     }
 #else
@@ -1870,26 +1871,24 @@ struct mmq_type_traits {
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a;
 };
 
-static int mmq_need_sum(const ggml_type type_x) {
+static bool mmq_need_sum(const ggml_type type_x) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
-            return 1;
+            return true;
         case GGML_TYPE_Q5_0:
-            return 0;
+            return false;
         case GGML_TYPE_Q5_1:
-            return 1;
+            return true;
         case GGML_TYPE_Q8_0:
-            return 0;
         case GGML_TYPE_Q2_K:
-            return 2;
         case GGML_TYPE_Q3_K:
-            return 0;
+            return false;
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
-            return 1;
+            return true;
         case GGML_TYPE_Q6_K:
-            return 0;
+            return false;
         default:
             GGML_ASSERT(false);
             break;
diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu
index 8d61d8bd6..b46786822 100644
--- a/ggml-cuda/quantize.cu
+++ b/ggml-cuda/quantize.cu
@@ -1,5 +1,4 @@
 #include "quantize.cuh"
-#include
 #include <cstdint>
 
 static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
@@ -38,7 +37,7 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-template <int need_sum>
+template <bool need_sum>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
 
@@ -61,48 +60,24 @@ static __global__ void quantize_mmq_q8_1(
 
     amax = warp_reduce_max(amax);
 
+    float sum;
+    if (need_sum) {
+        sum = warp_reduce_sum(xi);
+    }
+
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 
     y[ib].qs[iqs] = q;
 
-    static_assert(need_sum >= 0 && need_sum <= 2, "Invalid need_sum value.");
-    if (need_sum == 0) {
-        if (iqs % QK8_1 != 0) {
-            return;
-        }
-
-        ((float *) y[ib].ds)[iqs/QK8_1] = d;
-    } else if (need_sum == 1) {
-        const float sum = warp_reduce_sum(xi);
-
-        if (iqs % QK8_1 != 0) {
-            return;
-        }
+    if (iqs % QK8_1 != 0) {
+        return;
+    }
 
+    if (need_sum) {
         y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
     } else {
-        float sum = xi;
-
-        // Calculate sum per 16 values:
-#pragma unroll
-        for (int mask = 8; mask > 0; mask >>= 1) {
-            sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-        }
-
-        if (iqs % (QK8_1/2) != 0) {
-            return;
-        }
-
-        int8_t * si = (int8_t *) &y[ib].ds[iqs/QK8_1].y;
-        const int tmp = roundf(amax == 0.0f ? 0.0f : -8*sum/amax);
-        si[(iqs % QK8_1)/(QK8_1/2)] = min(tmp, 127);
-
-        if (iqs % QK8_1 != 0) {
-            return;
-        }
-
-        reinterpret_cast<half&>(y[ib].ds[iqs/QK8_1].x) = d;
+        ((float *) y[ib].ds)[iqs/QK8_1] = d;
     }
 }
@@ -129,14 +104,9 @@ void quantize_mmq_q8_1_cuda(
     const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
     const dim3 num_blocks(block_num_x, kx1, channels);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    const int need_sum = mmq_need_sum(type_x);
-    if (need_sum == 0) {
-        quantize_mmq_q8_1<0><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else if (need_sum == 1) {
-        quantize_mmq_q8_1<1><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else if (need_sum == 2) {
-        quantize_mmq_q8_1<2><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    if (mmq_need_sum(type_x)) {
+        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
     } else {
-        GGML_ASSERT(false);
+        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
     }
 }
diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh
index 6bf4d6b7a..3b12d6566 100644
--- a/ggml-cuda/vecdotq.cuh
+++ b/ggml-cuda/vecdotq.cuh
@@ -265,32 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
 // contiguous u/y values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const half2 & ds8) {
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
 
-    const float d8 = __low2float(ds8);
-    const int8_t * s8i = (const int8_t *) &ds8.y;
-
 #pragma unroll
     for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
         const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
         int sumi_d = 0;
+        int sumi_m = 0;
 
         const int vi0 = v[i0/(QI8_1/2)];
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
             const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
-            sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product
+            sumi_d = __dp4a(vi,         u[i], sumi_d); // SIMD dot product
+            sumi_m = __dp4a(0x01010101, u[i], sumi_m);
         }
 
         sumf_d += dm2f.x * sumi_d;
-        sumf_m += dm2f.y * s8i[i0/(QI8_1/2)];
+        sumf_m += dm2f.y * sumi_m;
    }
 
-    return d8*(sumf_d + (127.0f/8.0f)*sumf_m);
+    return d8*(sumf_d - sumf_m);
 #else
     NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
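
Note, not part of the patch: the revert is possible because of how the q2_K dot product factors. Each pair of sub-block scales dm2 expands a 2-bit value q roughly as x = dm2.x*q - dm2.y, so against q8_1 data the dot product splits into sum(x_i*y_i) = dm2.x*sum(q_i*y_i) - dm2.y*sum(y_i). The reverted code precomputed sum(y_i) per 16 values at quantization time and packed it into ds.y as a rounded int8 (hence the (127.0f/8.0f) rescaling on the consumer side), which is what cost precision. After the revert the term is recomputed exactly with a dp4a against a register of ones: every byte lane of 0x01010101 is 1, so __dp4a(0x01010101, u[i], acc) adds the four signed int8 values packed in u[i] to acc. The mma path does the same thing at tile granularity by feeding A1.x[0] = A1.x[1] = 0x01010101 into mma_K4 so that Cm accumulates the column sums of B. A minimal standalone sketch of that byte-sum identity (kernel name, file layout, and test values are illustrative, not taken from the patch; needs sm_61 or newer):

    // dp4a_byte_sum.cu: show that __dp4a against 0x01010101 sums packed int8s.
    #include <cstdio>
    #include <cstdint>
    #include <cstring>

    __global__ void byte_sum_via_dp4a(const int * u, int * out) {
        // Each byte lane of 0x01010101 is 1, so __dp4a accumulates the four
        // signed int8 values packed into *u.
        *out = __dp4a(0x01010101, *u, 0);
    }

    int main() {
        const int8_t bytes[4] = {-3, 7, 100, -50}; // expected sum: 54
        int u;
        memcpy(&u, bytes, sizeof(u));

        int * du = nullptr;
        int * dout = nullptr;
        cudaMalloc(&du, sizeof(int));
        cudaMalloc(&dout, sizeof(int));
        cudaMemcpy(du, &u, sizeof(int), cudaMemcpyHostToDevice);

        byte_sum_via_dp4a<<<1, 1>>>(du, dout);

        int out = 0;
        cudaMemcpy(&out, dout, sizeof(int), cudaMemcpyDeviceToHost);
        printf("byte sum = %d (expected 54)\n", out);

        cudaFree(du);
        cudaFree(dout);
        return 0;
    }

Compile with e.g. nvcc -arch=sm_61 dp4a_byte_sum.cu.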
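
Also not part of the patch: the deleted __shfl_xor_sync loop in quantize.cu was a butterfly reduction over groups of 16 lanes (masks 8, 4, 2, 1), needed because the packed sums covered 16 values rather than a whole q8_1 block; the surviving path simply calls warp_reduce_sum over the full warp. A generic sketch of the idiom, with the helper name and width parameter mine rather than ggml's:

    // Sketch: XOR-shuffle ("butterfly") reduction. After log2(width) steps,
    // every lane holds the sum over its group of `width` consecutive lanes.
    // width == 32 reduces the whole warp, like ggml's warp_reduce_sum;
    // width == 16 matches the deleted "sum per 16 values" loop.
    template <int width>
    static __device__ float group_reduce_sum(float v) {
    #pragma unroll
        for (int mask = width/2; mask > 0; mask >>= 1) {
            v += __shfl_xor_sync(0xffffffff, v, mask, 32);
        }
        return v;
    }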