revert q2_K precision related changes

Johannes Gäßler 2024-06-14 17:43:33 +02:00
parent bff3a20944
commit 1d9dd480ff
3 changed files with 43 additions and 75 deletions


@@ -882,8 +882,8 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -895,7 +895,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
&x_dm[i*(WARP_SIZE + 1) + k0], y_ds[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
&x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
}
}
}
@@ -911,7 +911,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
typedef mma_int_C_I16J8 mma_C;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const float * y_df = (const float *) y;
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
@@ -944,10 +944,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
mma_C C[2];
mma_C Cd[2];
mma_C Cm[2];
mma_B B[2];
float dB[mma_C::ne/2];
float sB[mma_C::ne/2][2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
@@ -961,20 +961,21 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
const half2 tmp = y_ds[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
dB[l] = __low2float(tmp);
const int8_t * sBi = (const int8_t *) &tmp.y;
sB[l][0] = (127.0f/8.0f)*sBi[0];
sB[l][1] = (127.0f/8.0f)*sBi[1];
dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
}
C[0].mma_K4(A[0], B[0]);
C[1].mma_K4(A[1], B[1]);
Cd[0].mma_K4(A[0], B[0]);
Cd[1].mma_K4(A[1], B[1]);
mma_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
Cm[0].mma_K4(A1, B[0]);
Cm[1].mma_K4(A1, B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*dA[l/2][0] + C[1].x[l]*dA[l/2][1] + mA[l/2][0]*sB[l%2][0] + mA[l/2][1]*sB[l%2][1])*dB[l%2];
sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
}
}
#else
@@ -1870,26 +1871,24 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
static int mmq_need_sum(const ggml_type type_x) {
static bool mmq_need_sum(const ggml_type type_x) {
switch (type_x) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return 1;
return true;
case GGML_TYPE_Q5_0:
return 0;
return false;
case GGML_TYPE_Q5_1:
return 1;
return true;
case GGML_TYPE_Q8_0:
return 0;
case GGML_TYPE_Q2_K:
return 2;
case GGML_TYPE_Q3_K:
return 0;
return false;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
return 1;
return true;
case GGML_TYPE_Q6_K:
return 0;
return false;
default:
GGML_ASSERT(false);
break;
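
For orientation, a scalar model of the per-element epilogue this hunk restores in vec_dot_q2_K_q8_1_mma; this is a sketch, not code from the commit, and the helper name and plain int/float parameters are illustrative only:

// Scalar model of the restored combine: Cd holds the integer q2*q8 dot
// products of the two K slices, Cm holds the plain q8 sums produced by the
// extra MMA against 0x01010101, dA/mA are the Q2_K sub-block scales/minima,
// and dB is the per-block Q8_1 scale.
__host__ __device__ static float q2k_mma_combine(
        int Cd0, int Cd1, int Cm0, int Cm1,
        float dA0, float dA1, float mA0, float mA1, float dB) {
    return (Cd0*dA0 + Cd1*dA1 - Cm0*mA0 - Cm1*mA1) * dB;
}

Computing the Cm term on the fly with a second MMA against 0x01010101 is what lets the q8_1 tile metadata go back to a single float scale per block (y_df) instead of carrying a precomputed sum.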


@@ -1,5 +1,4 @@
#include "quantize.cuh"
#include <cmath>
#include <cstdint>
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
@@ -38,7 +37,7 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
template <int need_sum>
template <bool need_sum>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
@@ -61,48 +60,24 @@ static __global__ void quantize_mmq_q8_1(
amax = warp_reduce_max(amax);
float sum;
if (need_sum) {
sum = warp_reduce_sum(xi);
}
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
static_assert(need_sum >= 0 && need_sum <= 2, "Invalid need_sum value.");
if (need_sum == 0) {
if (iqs % QK8_1 != 0) {
return;
}
((float *) y[ib].ds)[iqs/QK8_1] = d;
} else if (need_sum == 1) {
const float sum = warp_reduce_sum(xi);
if (iqs % QK8_1 != 0) {
return;
}
if (iqs % QK8_1 != 0) {
return;
}
if (need_sum) {
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
} else {
float sum = xi;
// Calculate sum per 16 values:
#pragma unroll
for (int mask = 8; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
}
if (iqs % (QK8_1/2) != 0) {
return;
}
int8_t * si = (int8_t *) &y[ib].ds[iqs/QK8_1].y;
const int tmp = roundf(amax == 0.0f ? 0.0f : -8*sum/amax);
si[(iqs % QK8_1)/(QK8_1/2)] = min(tmp, 127);
if (iqs % QK8_1 != 0) {
return;
}
reinterpret_cast<half&>(y[ib].ds[iqs/QK8_1].x) = d;
((float *) y[ib].ds)[iqs/QK8_1] = d;
}
}
@@ -129,14 +104,9 @@ void quantize_mmq_q8_1_cuda(
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, kx1, channels);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
const int need_sum = mmq_need_sum(type_x);
if (need_sum == 0) {
quantize_mmq_q8_1<0><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
} else if (need_sum == 1) {
quantize_mmq_q8_1<1><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
} else if (need_sum == 2) {
quantize_mmq_q8_1<2><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
if (mmq_need_sum(type_x)) {
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
} else {
GGML_ASSERT(false);
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
}
}
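
A minimal sketch of the two block-metadata layouts the restored bool template selects between, using the same per-value math as the kernel above; the helper name and the raw ds_slot pointer are made up, and the real kernel writes the metadata from a single lane per 32-value block:

#include <cstdint>
#include <cuda_fp16.h>

// Per-value quantization plus the two metadata layouts: half2(d, sum) when the
// consuming MMQ dot kernel needs the precomputed sum (mmq_need_sum(type_x) is
// true), a single float scale d otherwise.
__device__ static void q8_1_store_sketch(float xi, float amax, float sum,
                                         bool need_sum, int8_t * qs, void * ds_slot) {
    const float  d = amax / 127;
    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
    *qs = q;
    if (need_sum) {
        *(half2 *) ds_slot = make_half2(d, sum);  // read back through y_ds
    } else {
        *(float *) ds_slot = d;                   // read back through y_df
    }
}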


@@ -265,32 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
// contiguous u/y values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const half2 & ds8) {
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
float sumf_d = 0.0f;
float sumf_m = 0.0f;
const float d8 = __low2float(ds8);
const int8_t * s8i = (const int8_t *) &ds8.y;
#pragma unroll
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
int sumi_d = 0;
int sumi_m = 0;
const int vi0 = v[i0/(QI8_1/2)];
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product
sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product
sumi_m = __dp4a(0x01010101, u[i], sumi_m);
}
sumf_d += dm2f.x * sumi_d;
sumf_m += dm2f.y * s8i[i0/(QI8_1/2)];
sumf_m += dm2f.y * sumi_m;
}
return d8*(sumf_d + (127.0f/8.0f)*sumf_m);
return d8*(sumf_d - sumf_m);
#else
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
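
For reference, the identity behind the restored return value d8*(sumf_d - sumf_m): a Q2_K weight dequantizes to d2*q - m2, so its dot product with d8-scaled int8 activations splits into a q*u term and a plain sum-of-u term, which the second __dp4a against 0x01010101 accumulates four values at a time. A minimal scalar sketch, not part of the commit; the function name and flat-array interface are illustrative:

#include <cstdint>

// Scalar reference for one Q2_K sub-block against Q8_1 data: with weights
// dequantizing to d2*q[i] - m2 and activations to d8*u[i], the dot product
// factors into d8*(d2*sum(q[i]*u[i]) - m2*sum(u[i])).
__host__ __device__ static float q2k_q8_1_ref(const uint8_t * q, const int8_t * u,
                                              int n, float d2, float m2, float d8) {
    int sumi_d = 0;  // sum of q[i]*u[i]
    int sumi_m = 0;  // sum of u[i]
    for (int i = 0; i < n; ++i) {
        sumi_d += q[i] * u[i];
        sumi_m += u[i];
    }
    return d8 * (d2*sumi_d - m2*sumi_m);
}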