From ccf4901334b8abaeded589a75eb2ecb09c1b81dd Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 17:39:25 +0300
Subject: [PATCH] k_quants: change Q5_K to be type 0 when QK_K = 64

Still needs AVX2 implementation

---
 ggml-cuda.cu     |  32 +++++++++-------
 ggml-metal.metal |  46 ++++++++++---------
 k_quants.c       | 113 +++++++++++++++++++++++++----------------------
 k_quants.h       |   5 ++-
 4 files changed, 103 insertions(+), 93 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 840c0a863..a97d351a1 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -164,11 +164,12 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2,
 
 #ifdef GGML_QKK_64
 typedef struct {
-    half d[2*QK_K/32];          // super-block scales/mins
-    uint8_t qh[QK_K/8];         // quants, high bit
-    uint8_t qs[QK_K/2];         // quants, low 4 bits
+    half d;                     // super-block scale
+    int8_t scales[QK_K/16];     // block scales
+    uint8_t qh[QK_K/8];         // quants, high bit
+    uint8_t qs[QK_K/2];         // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
     half d;               // super-block scale for quantized scales
@@ -546,12 +547,14 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 #else
     const int tid = threadIdx.x;
     const uint8_t q = x[i].qs[tid];
-    const int im = tid/8;  // 0...3
-    const int in = tid%8;  // 0...7
+    const int im = tid/8;   // 0...3
+    const int in = tid%8;   // 0...7
+    const int is = tid/16;  // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
     float * y = yy + i*QK_K + tid;
-    y[ 0] = (float)x[i].d[0] * ((q & 0xF) + ((h >> 0) & 1 ? 16 : 0)) - (float)x[i].d[1];
-    y[32] = (float)x[i].d[2] * ((q >>  4) + ((h >> 4) & 1 ? 16 : 0)) - (float)x[i].d[3];
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
 #endif
 
 }
@@ -992,17 +995,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
 
         const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
         const float   * y = yy + i*QK_K + step;
-        const half2   * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const float     d = x[i].d;
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
             const uint8_t h = x[i].qh[in+j] >> im;
-            sum += y[j+ 0] * (df1.x * ((q[j+ 0] & 0xF) + (((h >> 0) & 1) << 4)) - df1.y)
-                 + y[j+16] * (df1.x * ((q[j+16] & 0xF) + (((h >> 2) & 1) << 4)) - df1.y)
-                 + y[j+32] * (df2.x * ((q[j+ 0] >>  4) + (((h >> 4) & 1) << 4)) - df2.y)
-                 + y[j+48] * (df2.x * ((q[j+16] >>  4) + (((h >> 6) & 1) << 4)) - df2.y);
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
         }
         tmp += sum;
    }
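Background note (not part of the patch): with QK_K = 64, Q5_K moves from a "type 1" block (per-32-weight fp16 scale/min pairs, y = d*q - m, as in Q4_1) to a "type 0" block (one fp16 super-block scale times an int8 scale per 16 weights applied to symmetric quants, as in Q4_0). The `(q & 0xF) - (hbit ? 0 : 16)` expression folds two steps into one ternary: quants are stored as l + 16 with l in [-16, 15], so adding 16 when the packed high bit is set and then subtracting the constant 16 offset is the same as subtracting 16 only when the high bit is clear. A minimal C sketch of the same arithmetic for one block; the struct and function names are illustrative, and d is kept as a plain float so the sketch compiles without ggml's fp16 helpers:

/* --- illustrative sketch, not part of the patch --- */
#include <stdint.h>

#define QK_K 64

typedef struct {
    float   d;                 // ggml_fp16_t in the real struct
    int8_t  scales[QK_K/16];   // 4 block scales
    uint8_t qh[QK_K/8];        // high bits, 1 per weight
    uint8_t qs[QK_K/2];        // low 4 bits, 2 weights per byte
} block_q5_K_sketch;

static void dequantize_q5_K_64_sketch(const block_q5_K_sketch * x, float * y) {
    for (int l = 0; l < 32; ++l) {
        const int q = x->qs[l];
        const int h = x->qh[l % 8] >> (l / 8);   // same in/im indexing as the CUDA kernel
        y[l     ] = x->d * x->scales[l/16 + 0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
        y[l + 32] = x->d * x->scales[l/16 + 2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
    }
}
/* --- end note --- */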
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 0a170531b..e62fe6842 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -823,7 +823,8 @@ typedef struct {
 
 #if QK_K == 64
 typedef struct {
-    half4 d;                    // super-block scales/mins
+    half d;                     // super-block scale
+    int8_t scales[QK_K/16];     // 8-bit block scales
     uint8_t qh[QK_K/8];         // quants, high bit
     uint8_t qs[QK_K/2];         // quants, low 4 bits
 } block_q5_K;
@@ -1062,20 +1063,21 @@ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, i
 #else
     for (int i = 0; i < nb; i++) {
 
-        const float4 d = (float4)x[i].d;
+        const float d = (float)x[i].d;
 
         device const uint8_t * ql = x[i].qs;
         device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
 
         for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1];
-            y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1];
-            y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - d[1];
-            y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1];
-            y[l+32] = d[2] * ((ql[l+ 0] >>  4) + (qh[l] & 0x10 ? 16 : 0)) - d[3];
-            y[l+40] = d[2] * ((ql[l+ 8] >>  4) + (qh[l] & 0x20 ? 16 : 0)) - d[3];
-            y[l+48] = d[2] * ((ql[l+16] >>  4) + (qh[l] & 0x40 ? 16 : 0)) - d[3];
-            y[l+56] = d[2] * ((ql[l+24] >>  4) + (qh[l] & 0x80 ? 16 : 0)) - d[3];
+            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
         }
         y += QK_K;
     }
@@ -1336,12 +1338,6 @@ kernel void kernel_mul_mat_q3_K_f32(
                         uint2 tpitg[[thread_position_in_threadgroup]],
                         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -1355,6 +1351,12 @@ kernel void kernel_mul_mat_q3_K_f32(
 
 #if QK_K == 256
 
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = tpitg.y;        // expecting 16
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
@@ -1682,18 +1684,18 @@ kernel void kernel_mul_mat_q5_K_f32(
 
     for (int i = tpitg.y; i < nb; i += tptg.y) {
 
+        const float d = (float)x[i].d;
         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
+        device const int8_t  * s = x[i].scales;
         device const float   * y = yy + i*QK_K + il;
 
-        const float4 d = (float4)x[i].d;
-
         for (int l = 0; l < 4; ++l) {
             const uint8_t hl = h[l] >> im;
-            sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1])
-                  + y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1])
-                  + y[l+32] * (d[2] * ((q[l+ 0] >>  4) + (hl & 0x10 ? 16 : 0)) - d[3])
-                  + y[l+48] * (d[2] * ((q[l+16] >>  4) + (hl & 0x40 ? 16 : 0)) - d[3]);
+            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
         }
    }
 #endif
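Background note (not part of the patch): the Metal changes mirror the CUDA ones; the two kernel_mul_mat_q3_K_f32 hunks just move the kmask/m3/m4 constants into the QK_K == 256 branch where they are actually used. The dequantization loops above also pin down the bit layout: qs[j] holds weight j in its low nibble and weight j+32 in its high nibble, and bit w/8 of qh[w%8] is the high (fifth) bit of weight w. The quantize-side packing is not shown in this patch; the sketch below is derived from those loops, with illustrative names:

/* --- illustrative sketch, derived from the dequant loops above --- */
#include <stdint.h>
#include <string.h>

// L[w] holds the 5-bit quants 0..31 for one QK_K == 64 block.
static void pack_q5_K_64_sketch(const uint8_t L[64], uint8_t qs[32], uint8_t qh[8]) {
    for (int j = 0; j < 32; ++j) {
        qs[j] = (uint8_t)((L[j] & 0xF) | ((L[j + 32] & 0xF) << 4));  // low/high nibble pair
    }
    memset(qh, 0, 8);
    for (int w = 0; w < 64; ++w) {
        if (L[w] & 0x10) qh[w % 8] |= (uint8_t)(1 << (w / 8));       // high bit of weight w
    }
}
/* --- end note --- */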
diff --git a/k_quants.c b/k_quants.c
index 3d395e59f..a6221df9e 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -792,10 +792,13 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-    uint8_t L[QK_K];
 #if QK_K == 256
+    uint8_t L[QK_K];
     float mins[QK_K/32];
     float scales[QK_K/32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
 #endif
 
     for (int i = 0; i < nb; i++) {
@@ -869,20 +872,30 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
             ql += 32;
         }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
         }
-        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
             if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
-                l = MAX(0, MIN(31, l));
-                L[32*j + ii] = l;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
             }
         }
@@ -938,17 +951,17 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
             u1 <<= 2; u2 <<= 2;
         }
 #else
-        float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-        float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+        float d = ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
         for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d1 * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - m1;
-            y[l+ 8] = d1 * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - m1;
-            y[l+16] = d1 * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - m1;
-            y[l+24] = d1 * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - m1;
-            y[l+32] = d2 * ((ql[l+ 0] >>  4) + (qh[l] & 0x10 ? 16 : 0)) - m2;
-            y[l+40] = d2 * ((ql[l+ 8] >>  4) + (qh[l] & 0x20 ? 16 : 0)) - m2;
-            y[l+48] = d2 * ((ql[l+16] >>  4) + (qh[l] & 0x40 ? 16 : 0)) - m2;
-            y[l+56] = d2 * ((ql[l+24] >>  4) + (qh[l] & 0x80 ? 16 : 0)) - m2;
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
         }
         y += QK_K;
 #endif
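Background note (not part of the patch): the new quantize path is a standard two-level type-0 scheme: make_qx_quants picks a float scale per 16 weights, the largest-magnitude scale is pinned to -128 so the int8 block scales use their full range (hence iscale = -128.f/max_scale and d = 1/iscale), and the weights are then re-quantized against the rounded effective scale d*scales[j]. A self-contained sketch of just the scale-quantization step; it mirrors the loop above, including its implicit assumption that max_scale is nonzero, and the names plus the simplified nearest_int are illustrative:

/* --- illustrative sketch of the scale quantization above --- */
#include <math.h>
#include <stdint.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

static inline int nearest_int(float f) { return (int)floorf(f + 0.5f); }

// scales[] are the per-16-weight float scales from make_qx_quants;
// returns the super-block scale d (stored as fp16 in the real code).
static float quantize_scales_q5_K_64_sketch(const float scales[4], int8_t q_scales[4]) {
    float max_scale = 0, amax = 0;
    for (int j = 0; j < 4; ++j) {
        const float abs_scale = fabsf(scales[j]);
        if (abs_scale > amax) { amax = abs_scale; max_scale = scales[j]; }
    }
    const float iscale = -128.f/max_scale;   // largest |scale| maps to -128
    for (int j = 0; j < 4; ++j) {
        q_scales[j] = (int8_t)MAX(-128, MIN(127, nearest_int(iscale*scales[j])));
    }
    return 1/iscale;
}
/* --- end note --- */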
@@ -2751,19 +2764,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     float sumf = 0;
 
-    float32x4_t acc1 = vdupq_n_f32(0.f);
-    float32x4_t acc2 = vdupq_n_f32(0.f);
-
     for (int i = 0; i < nb; ++i) {
 
-        const float16x4_t s16 = vld1_f16(x[i].d);
-        const float32x4_t s32 = vmulq_n_f32(vcvt_f32_f16(s16), y[i].d);
-        //const int16x4_t bi16 = vld1_s16(y[i].bsums);
-        //const int32x4_t bi32 = vmovl_s16(vpadd_s16(bi16, bi16));
-        //const float32x4_t bf32 = vcvtq_f32_s32(bi32);
-        //sumf -= (vgetq_lane_f32(s32, 1) * vgetq_lane_f32(bf32, 0) + vgetq_lane_f32(s32, 3) * vgetq_lane_f32(bf32, 1));
-        // The above is slightly slower than just this:
-        sumf -= (vgetq_lane_f32(s32, 1) * (y[i].bsums[0] + y[i].bsums[1]) + vgetq_lane_f32(s32, 3) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        sumf -= 16.f * d * (sc[0] * y[i].bsums[0] + sc[1] * y[i].bsums[1] + sc[2] * y[i].bsums[2] + sc[3] * y[i].bsums[3]);
 
         const uint8_t * restrict q5 = x[i].qs;
         const uint8_t * restrict qh = x[i].qh;
@@ -2787,34 +2793,35 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
 #if defined(__ARM_FEATURE_DOTPROD)
 
-        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])),
-                vgetq_lane_f32(s32, 0));
-        acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])),
-                vgetq_lane_f32(s32, 2));
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
 #else
 
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                        vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
         const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                        vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        const int16x8_t p01_16 = vaddq_s16(p0, p1);
-        const int32x4_t p01_32 = vaddq_s32(vmovl_s16(vget_low_s16(p01_16)), vmovl_s16(vget_high_s16(p01_16)));
-        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(p01_32), vgetq_lane_f32(s32, 0));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
 
         const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
                                        vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        const int16x8_t p02_16 = vaddq_s16(p2, p3);
-        const int32x4_t p02_32 = vaddq_s32(vmovl_s16(vget_low_s16(p02_16)), vmovl_s16(vget_high_s16(p02_16)));
-        acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(p02_32), vgetq_lane_f32(s32, 2));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+
+        sumf += d*sumi;
 
 #endif
 
     }
 
-    *s = vaddvq_f32(vaddq_f32(acc1, acc2)) + sumf;
+    *s = sumf;
 
-#elif defined __AVX2__
+#elif defined z__AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
     const __m256i m1 = _mm256_set1_epi16(1);
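Background note (not part of the patch): two things are worth spelling out in the vec_dot changes. First, `#elif defined z__AVX2__` is a deliberate never-matching guard: it disables the stale AVX2 branch, which still assumes the old scales/mins layout, so compilation falls through to the generic path, consistent with "Still needs AVX2 implementation" in the commit message. Second, the upfront `sumf -= 16.f * d * (sc[0]*bsums[0] + ...)` works because the -16 offset is constant within each 16-weight block, so it factors out of the dot product against the q8_K quants, whose per-16 sums are precomputed in bsums; the inner loop is then a pure integer dot product on the raw 5-bit quants. A runnable toy check of the identity:

/* --- toy verification, not part of the patch --- */
#include <stdio.h>

int main(void) {
    const int   q[4]  = {3, 17, 30, 0};   // 5-bit quants, offset not yet applied
    const int   q8[4] = {-5, 7, 2, -1};   // q8_K quants
    const float d = 0.25f, sc = -3.0f;    // super-block scale times block scale

    float direct = 0; int dot = 0, bsum = 0;
    for (int l = 0; l < 4; ++l) {
        direct += d*sc*(q[l] - 16)*q8[l]; // what the dequantized dot product computes
        dot    += q[l]*q8[l];             // pure integer dot product
        bsum   += q8[l];                  // precomputed in y[i].bsums
    }
    // d*sc*dot - 16*d*sc*bsum equals direct, so the -16 is applied once per block
    printf("%g == %g\n", direct, d*sc*dot - 16.f*d*sc*bsum);
    return 0;
}
/* --- end note --- */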
@@ -2884,19 +2891,17 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         for (int is = 0; is < 8; ++is) {
            uint8_t m = 1 << is;
-            for (int l = 0; l < 8; ++l) a[8*is + l] += (hm[l] & m ? 16 : 0);
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
         }
 
-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;
 
-        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
             for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
             q8 += 16; a += 16;
-            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
-            q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[8+l]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
diff --git a/k_quants.h b/k_quants.h
index 6256ae167..6abe3d7b8 100644
--- a/k_quants.h
+++ b/k_quants.h
@@ -80,11 +80,12 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/
 // Effectively 5.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32];    // super-block scales/mins
+    ggml_fp16_t d;               // super-block scale
+    int8_t  scales[QK_K/16];     // 8-bit block scales
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;               // super-block scale for quantized scales
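Background note (not part of the patch): the new QK_K = 64 block is 2 (fp16 d) + 4 (int8 scales) + 8 (qh) + 32 (qs) = 46 bytes per 64 weights, i.e. 5.75 bits per weight, down from 48 bytes (6.0 bpw) with the old d[2*QK_K/32] scales/mins layout; the "Effectively 5.5 bits per weight" comment above refers to the QK_K = 256 layout. A compile-time check of the arithmetic, using a stand-in struct so it builds outside the tree:

/* --- compile-time size check, mirrors the static_assert above --- */
#include <assert.h>
#include <stdint.h>

#define QK_K 64
typedef uint16_t ggml_fp16_t;   // stand-in for ggml's fp16 type

typedef struct {
    ggml_fp16_t d;              // super-block scale
    int8_t  scales[QK_K/16];    // 8-bit block scales
    uint8_t qh[QK_K/8];         // quants, high bit
    uint8_t qs[QK_K/2];         // quants, low 4 bits
} block_q5_K_sketch;

static_assert(sizeof(block_q5_K_sketch) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16,
              "wrong q5_K block size/padding");  // 2 + 32 + 8 + 4 = 46 bytes for 64 weights
/* --- end note --- */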