From ff0efc747d109b7f255a3318fd0c3d0812d198f4 Mon Sep 17 00:00:00 2001 From: qwopqwop200 Date: Thu, 13 Apr 2023 14:54:44 +0900 Subject: [PATCH] add Q4_2 --- ggml.c | 625 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 623 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index a26b4853f..2081e26fd 100644 --- a/ggml.c +++ b/ggml.c @@ -411,6 +411,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // #define QK 32 +#define QK128 128 // AVX routines provided by GH user Const-me // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 @@ -502,6 +503,16 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); +// method 6 +// blocks of QK elements for GPTQ +// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) +typedef struct { + float d; + float m; + uint8_t qs[QK128 / 2]; // nibbles / quants +} block_q4_2; +static_assert(sizeof(block_q4_2) == sizeof(float) * 2 + QK128 / 2, "wrong q4_2 block size/padding"); + // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { assert(k % QK == 0); @@ -954,6 +965,267 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int #endif } +static void quantize_row_q4_2_reference(const float * restrict x, void * restrict vy, int k) { + assert(k % QK128 == 0); + const int nb = k / QK128; + + block_q4_2 * restrict y = vy; + + uint8_t pp[QK128/2]; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK128; l++) { + const float v = x[i*QK128 + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + y[i].m = min; + + for (int l = 0; l < QK128; l += 2) { + const float v0 = (x[i*QK128 + l + 0] - min)*id; + const float v1 = (x[i*QK128 + l + 1] - min)*id; + + const uint8_t vi0 = roundf(v0); + const uint8_t vi1 = roundf(v1); + + assert(vi0 < 16); + assert(vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(y[i].qs, pp, sizeof(pp)); + } +} + +static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + block_q4_2 * restrict y = vy; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + __m256 v4 = _mm256_loadu_ps( x + 32 ); + __m256 v5 = _mm256_loadu_ps( x + 40 ); + __m256 v6 = _mm256_loadu_ps( x + 48 ); + __m256 v7 = _mm256_loadu_ps( x + 56 ); + __m256 v8 = _mm256_loadu_ps( x + 64 ); + __m256 v9 = _mm256_loadu_ps( x + 72 ); + __m256 v10 = _mm256_loadu_ps( x + 80 ); + __m256 v11 = _mm256_loadu_ps( x + 88 ); + __m256 v12 = _mm256_loadu_ps( x + 96 ); + __m256 v13 = _mm256_loadu_ps( x + 104 ); + __m256 v14 = _mm256_loadu_ps( x + 112 ); + __m256 v15 = _mm256_loadu_ps( x + 120 ); + x += 128; + + // Compute max for the block + __m256 vmax; + vmax = _mm256_max_ps( v0, v1 ); + vmax = _mm256_max_ps( vmax, v2 ); + vmax = _mm256_max_ps( vmax, v3 ); + vmax = _mm256_max_ps( vmax, v4 ); + vmax = _mm256_max_ps( vmax, v5 ); + vmax = _mm256_max_ps( vmax, v6 ); + vmax = _mm256_max_ps( vmax, v7 ); + vmax = _mm256_max_ps( vmax, v8 ); + vmax = _mm256_max_ps( vmax, v9 ); + vmax = _mm256_max_ps( vmax, v10 ); + vmax = _mm256_max_ps( vmax, v11 ); + vmax = _mm256_max_ps( vmax, v12 ); + vmax = _mm256_max_ps( vmax, v13 ); + vmax = _mm256_max_ps( vmax, v14 ); + vmax = _mm256_max_ps( vmax, v15 ); + + __m128 max16 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) ); + max16 = _mm_max_ps( max16, _mm_movehl_ps( max16, max16 ) ); + max16 = _mm_max_ss( max16, _mm_movehdup_ps( max16 ) ); + const float maxScalar = _mm_cvtss_f32( max16 ); + + // Compute min for the block + __m256 vmin; + vmin = _mm256_min_ps( v0, v1 ); + vmin = _mm256_min_ps( vmin, v2 ); + vmin = _mm256_min_ps( vmin, v3 ); + vmin = _mm256_min_ps( vmin, v4 ); + vmin = _mm256_min_ps( vmin, v5 ); + vmin = _mm256_min_ps( vmin, v6 ); + vmin = _mm256_min_ps( vmin, v7 ); + vmin = _mm256_min_ps( vmin, v8 ); + vmin = _mm256_min_ps( vmin, v9 ); + vmin = _mm256_min_ps( vmin, v10 ); + vmin = _mm256_min_ps( vmin, v11 ); + vmin = _mm256_min_ps( vmin, v12 ); + vmin = _mm256_min_ps( vmin, v13 ); + vmin = _mm256_min_ps( vmin, v14 ); + vmin = _mm256_min_ps( vmin, v15 ); + + __m128 min16 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) ); + min16 = _mm_min_ps( min16, _mm_movehl_ps( min16, min16 ) ); + min16 = _mm_min_ss( min16, _mm_movehdup_ps( min16 ) ); + const float minScalar = _mm_cvtss_f32( min16 ); + + // Quantize these floats + const float d = (maxScalar - minScalar) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].m = minScalar; + y[i].d = d; + + // x = (x-min)*id + const __m256 mul = _mm256_set1_ps( id ); + const __m256 off = _mm256_set1_ps( minScalar ); + v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul ); + v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul ); + v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul ); + v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul ); + v4 = _mm256_mul_ps( _mm256_sub_ps( v4, off ), mul ); + v5 = _mm256_mul_ps( _mm256_sub_ps( v5, off ), mul ); + v6 = _mm256_mul_ps( _mm256_sub_ps( v6, off ), mul ); + v7 = _mm256_mul_ps( _mm256_sub_ps( v7, off ), mul ); + v8 = _mm256_mul_ps( _mm256_sub_ps( v8, off ), mul ); + v9 = _mm256_mul_ps( _mm256_sub_ps( v9, off ), mul ); + v10 = _mm256_mul_ps( _mm256_sub_ps( v10, off ), mul ); + v11 = _mm256_mul_ps( _mm256_sub_ps( v11, off ), mul ); + v12 = _mm256_mul_ps( _mm256_sub_ps( v12, off ), mul ); + v13 = _mm256_mul_ps( _mm256_sub_ps( v13, off ), mul ); + v14 = _mm256_mul_ps( _mm256_sub_ps( v14, off ), mul ); + v15 = _mm256_mul_ps( _mm256_sub_ps( v15, off ), mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + v4 = _mm256_round_ps( v4, _MM_ROUND_NEAREST ); + v5 = _mm256_round_ps( v5, _MM_ROUND_NEAREST ); + v6 = _mm256_round_ps( v6, _MM_ROUND_NEAREST ); + v7 = _mm256_round_ps( v7, _MM_ROUND_NEAREST ); + v8 = _mm256_round_ps( v8, _MM_ROUND_NEAREST ); + v9 = _mm256_round_ps( v9, _MM_ROUND_NEAREST ); + v10 = _mm256_round_ps( v10, _MM_ROUND_NEAREST ); + v11 = _mm256_round_ps( v11, _MM_ROUND_NEAREST ); + v12 = _mm256_round_ps( v12, _MM_ROUND_NEAREST ); + v13 = _mm256_round_ps( v13, _MM_ROUND_NEAREST ); + v14 = _mm256_round_ps( v14, _MM_ROUND_NEAREST ); + v15 = _mm256_round_ps( v15, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + __m256i i4 = _mm256_cvtps_epi32( v4 ); + __m256i i5 = _mm256_cvtps_epi32( v5 ); + __m256i i6 = _mm256_cvtps_epi32( v6 ); + __m256i i7 = _mm256_cvtps_epi32( v7 ); + __m256i i8 = _mm256_cvtps_epi32( v8 ); + __m256i i9 = _mm256_cvtps_epi32( v9 ); + __m256i i10 = _mm256_cvtps_epi32( v10 ); + __m256i i11 = _mm256_cvtps_epi32( v11 ); + __m256i i12 = _mm256_cvtps_epi32( v12 ); + __m256i i13 = _mm256_cvtps_epi32( v13 ); + __m256i i14 = _mm256_cvtps_epi32( v14 ); + __m256i i15 = _mm256_cvtps_epi32( v15 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + i4 = _mm256_packs_epi32( i4, i5 ); + i6 = _mm256_packs_epi32( i6, i7 ); + i8 = _mm256_packs_epi32( i8, i9 ); + i10 = _mm256_packs_epi32( i10, i11 ); + i12 = _mm256_packs_epi32( i12, i13 ); + i14 = _mm256_packs_epi32( i14, i15 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + i4 = _mm256_packs_epi16( i4, i6 ); + i8 = _mm256_packs_epi16( i8, i10 ); + i12 = _mm256_packs_epi16( i12, i14 ); + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + i4 = _mm256_permutevar8x32_epi32( i4, perm ); + i8 = _mm256_permutevar8x32_epi32( i8, perm ); + i12 = _mm256_permutevar8x32_epi32( i12, perm ); + + // Compress the vector into 4 bit/value, and store + __m128i res = packNibbles( i0 ); + _mm_storeu_si128( ( __m128i* )y[i].qs, res ); + + res = packNibbles( i4 ); + _mm_storeu_si128( (( __m128i* )y[i].qs) + 1, res ); + + res = packNibbles( i8 ); + _mm_storeu_si128( (( __m128i* )y[i].qs) + 2, res ); + + res = packNibbles( i12 ); + _mm_storeu_si128( (( __m128i* )y[i].qs) + 3, res ); + } +#elif __ARM_NEON + for (int i = 0; i < nb; i++) { + float32x4_t srcv[32]; + float32x4_t minv[32]; + float32x4_t maxv[32]; + + for (int l = 0; l < 32; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l); + + for (int l = 0; l < 16; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); + for (int l = 0; l < 8; l++) minv[4*l] = vminq_f32(srcv[4*l], srcv[4*l + 2]); + for (int l = 0; l < 4; l++) minv[8*l] = vminq_f32(srcv[8*l], srcv[8*l + 4]); + for (int l = 0; l < 2; l++) minv[16*l] = vminq_f32(minv[16*l], minv[16*l + 8]); + for (int l = 0; l < 1; l++) minv[32*l] = vminq_f32(minv[32*l], minv[32*l + 16]); + + for (int l = 0; l < 16; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]); + for (int l = 0; l < 8; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]); + for (int l = 0; l < 4; l++) maxv[8*l] = vmaxq_f32(srcv[8*l], srcv[8*l + 4]); + for (int l = 0; l < 2; l++) maxv[16*l] = vmaxq_f32(maxv[16*l], maxv[16*l + 8]); + for (int l = 0; l < 1; l++) maxv[32*l] = vmaxq_f32(maxv[32*l], maxv[32*l + 16]); + + const float min = vminvq_f32(minv[0]); + const float max = vmaxvq_f32(maxv[0]); + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + y[i].m = min; + + const float32x4_t minv0 = vdupq_n_f32(min); + + for (int l = 0; l < 32; l++) { + const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest + const int32x4_t vi = vcvtq_s32_f32(vf); + + y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + } +#else + // scalar + quantize_row_q4_2_reference(x, vy, k); +#endif +} + static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; @@ -1178,6 +1450,112 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in #endif } +static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) { + assert(k % QK128 == 0); + const int nb = k / QK128; + + const block_q4_2 * restrict x = vx; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + const __m256 d_v = _mm256_broadcast_ss(&x[i].d); + const __m256 d_m = _mm256_broadcast_ss(&x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK128; l += 32) { + // Load 32x4-bit integers into 32x8-bit integers + __m256i vx8 = bytesFromNibbles(pp+l/2); + + // Convert to 16-bit int + const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); + const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); + + // Convert to 32-bit int -> float 32 + const __m256 vf[4] = { + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) + }; + + // Scale, add m and store + for (int j = 0; j < 4; j++) { + const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); + _mm256_storeu_ps(y + i * QK128 + l + j*8, result); + } + } + } +#elif defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + const float32x4_t vd = vdupq_n_f32(x[i].d); + const float32x4_t vm = vdupq_n_f32(x[i].m); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK128; l += 16) { + // Load 16x4-bit integers into 8x8-bit integers + const uint8x8_t v8 = vld1_u8(pp + l/2); + + // Expand 4-bit qs to 8-bit bytes + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v1 = vshr_n_u8(v8, 4); + + // Interleave and combine + const uint8x8_t vx_0 = vzip1_u8(v0, v1); + const uint8x8_t vx_1 = vzip2_u8(v0, v1); + + const uint8x16_t vq = vcombine_u8(vx_0, vx_1); + + // convert to 2x uint16x8_t + const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq)); + const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq)); + + // convert to 4x float32x4_t + const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0))); + const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0))); + const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1))); + const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1))); + + // multiply by d and add m + const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd); + const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd); + const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd); + const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd); + + // Store + vst1q_f32(y + i*QK128 + l + 0, r0); + vst1q_f32(y + i*QK128 + l + 4, r1); + vst1q_f32(y + i*QK128 + l + 8, r2); + vst1q_f32(y + i*QK128 + l + 12, r3); + } + } +#else + for (int i = 0; i < nb; i++) { + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK128; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK128 + l + 0] = v0; + y[i*QK128 + l + 1] = v1; + + assert(!isnan(y[i*QK128 + l + 0])); + assert(!isnan(y[i*QK128 + l + 1])); + } + } +#endif +} + // // simd mappings // @@ -2318,6 +2696,164 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest *s = sumf; } +static void ggml_vec_dot_q4_2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK128; + + const block_q4_2 * restrict x = vx; + const block_q4_2 * restrict y = vy; + + float sumf = 0.0; + +#if defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + // Accumulator for constant offsets + float acc_offset = 0.0f; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float * d0 = &x[i].d; + const float * d1 = &y[i].d; + + const float * m0 = &x[i].m; + const float * m1 = &y[i].m; + + const __m256 d0v = _mm256_broadcast_ss( d0 ); + const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 m0v = _mm256_broadcast_ss( m0 ); + const __m256 m1v = _mm256_broadcast_ss( m1 ); + + // Compute combined scale for the block + const __m256 scale_01 = _mm256_mul_ps( d0v, d1v ); + + // Compute cross scales for the block + const __m256 scale_0 = _mm256_mul_ps( d0v, m1v ); + const __m256 scale_1 = _mm256_mul_ps( m0v, d1v ); + const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ ); + + const uint8_t * restrict x_pp = x[i].qs; + const uint8_t * restrict y_pp = y[i].qs; + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + for (int l = 0; l < QK128; l += 32) { + __m256i bx = bytesFromNibbles( x_pp + l/2); + __m256i by = bytesFromNibbles( y_pp + l/2); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. + + // Sign-extend first 16 signed bytes into int16_t + __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) ); + __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); + // Compute products of int16_t integers, add pairwise + __m256i i32 = _mm256_madd_epi16( x16, y16 ); + + // Sign-extend last 16 signed bytes into int16_t vectors + __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) ); + __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); + // Accumulate products of int16_t integers + i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) ); + + // compute sums of unsigned bytes in bx, by in blocks of 8. + // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000, + // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400. + // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ] + __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() ); + __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() ); + __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); + __m256 sums = _mm256_cvtepi32_ps( sumsi ); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps( i32 ); + // Apply the scale, and accumulate + // acc += d0*d1*x*y + d0*m1*x + d1*m0*y + acc = _mm256_fmadd_ps( scale_01, p, acc ); + acc = _mm256_fmadd_ps( cross_scales, sums, acc ); + } + // acc_offset += m0*m1 (for each entry in the block) + acc_offset += (*m0)*(*m1); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ) + acc_offset * QK128; +#elif defined(__ARM_NEON) + float sum00 = 0.0f; + float sum01 = 0.0f; + float sum10 = 0.0f; + float sum11 = 0.0f; + + for (int i = 0; i < nb; ++i) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict y0 = &y[i + 0]; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + + const uint8_t * restrict x_pp = x0->qs; + const uint8_t * restrict y_pp = x0->qs; + + for (int l = 0; l < QK128; l += 32) { + const uint8x16_t v0_0 = vld1q_u8(x_pp + l/2); + const uint8x16_t v1_0 = vld1q_u8(y_pp + l/2); + + // and with 0xf + const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); + const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); + + const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); + const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); + + // dot product into uint16x8_t + const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); + const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); + + const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); + const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); + + const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); + const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); + + sum00 += x0->m*y0->m; + sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); + sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); + sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); + } + } + + sumf = QK128*sum00 + sum01 + sum10 + sum11; +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = x[i].d; + const float d1 = y[i].d; + + const float m0 = x[i].m; + const float m1 = y[i].m; + + const uint8_t * restrict p0 = x[i].qs; + const uint8_t * restrict p1 = y[i].qs; + + for (int j = 0; j < QK128/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*(v0 & 0xf) + m0; + const float f1 = d0*(v0 >> 4) + m0; + + const float f2 = d1*(v1 & 0xf) + m1; + const float f3 = d1*(v1 >> 4) + m1; + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { @@ -2564,22 +3100,22 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F16] = 1, [GGML_TYPE_Q4_0] = QK, [GGML_TYPE_Q4_1] = QK, + [GGML_TYPE_Q4_2] = QK128, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), [GGML_TYPE_F16] = sizeof(ggml_fp16_t), [GGML_TYPE_Q4_0] = sizeof(block_q4_0), [GGML_TYPE_Q4_1] = sizeof(block_q4_1), + [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -3247,6 +3783,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3307,6 +3847,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3361,6 +3905,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3405,6 +3953,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3447,6 +3999,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3491,6 +4047,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_2: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -5229,6 +5789,7 @@ static void ggml_compute_forward_dup( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5310,6 +5871,7 @@ static void ggml_compute_forward_add( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5362,6 +5924,7 @@ static void ggml_compute_forward_sub( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5414,6 +5977,7 @@ static void ggml_compute_forward_mul( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5466,6 +6030,7 @@ static void ggml_compute_forward_div( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5514,6 +6079,7 @@ static void ggml_compute_forward_sqr( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5562,6 +6128,7 @@ static void ggml_compute_forward_sqrt( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5620,6 +6187,7 @@ static void ggml_compute_forward_sum( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5697,6 +6265,7 @@ static void ggml_compute_forward_mean( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5761,6 +6330,7 @@ static void ggml_compute_forward_repeat( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5809,6 +6379,7 @@ static void ggml_compute_forward_abs( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5857,6 +6428,7 @@ static void ggml_compute_forward_sgn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5905,6 +6477,7 @@ static void ggml_compute_forward_neg( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5953,6 +6526,7 @@ static void ggml_compute_forward_step( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6001,6 +6575,7 @@ static void ggml_compute_forward_relu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6066,6 +6641,7 @@ static void ggml_compute_forward_gelu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6133,6 +6709,7 @@ static void ggml_compute_forward_silu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6219,6 +6796,7 @@ static void ggml_compute_forward_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6299,6 +6877,7 @@ static void ggml_compute_forward_rms_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6708,6 +7287,12 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, .vec_dot_q = ggml_vec_dot_q4_1, }, + [GGML_TYPE_Q4_2] = { + .dequantize_row_q = dequantize_row_q4_2, + .quantize_row_q = quantize_row_q4_2, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference, + .vec_dot_q = ggml_vec_dot_q4_2, + }, }; // For internal test use @@ -6915,6 +7500,10 @@ static void ggml_compute_forward_mul_mat( { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_2: + { + ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); @@ -6953,6 +7542,26 @@ static void ggml_compute_forward_mul_mat( printf("\n"); exit(0); } + } else if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_2) { + static int first = 8; + printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + if (first) { + --first; + } else { + for (int k = 0; k < dst->ne[1]; ++k) { + for (int j = 0; j < dst->ne[0]/16; ++j) { + for (int i = 0; i < 16; ++i) { + printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + } + printf("\n"); + } + printf("\n"); + } + printf("\n"); + exit(0); + } } else { printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); @@ -7010,6 +7619,7 @@ static void ggml_compute_forward_scale( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7178,6 +7788,10 @@ static void ggml_compute_forward_get_rows( { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_2: + { + ggml_compute_forward_get_rows_q(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_get_rows_f16(params, src0, src1, dst); @@ -7264,6 +7878,7 @@ static void ggml_compute_forward_diag_mask_inf( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7358,6 +7973,7 @@ static void ggml_compute_forward_soft_max( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7533,6 +8149,7 @@ static void ggml_compute_forward_rope( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7801,6 +8418,7 @@ static void ggml_compute_forward_conv_1d_1s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8069,6 +8687,7 @@ static void ggml_compute_forward_conv_1d_2s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8554,6 +9173,7 @@ static void ggml_compute_forward_flash_attn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8765,6 +9385,7 @@ static void ggml_compute_forward_flash_ff( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_2: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: