This commit is contained in:
qwopqwop200 2023-04-13 14:54:44 +09:00 committed by GitHub
parent f0b14e8c69
commit ff0efc747d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

625
ggml.c
View file

@ -411,6 +411,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
//
#define QK 32
#define QK128 128
// AVX routines provided by GH user Const-me
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
@ -502,6 +503,16 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
// method 6
// blocks of QK elements for GPTQ
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
typedef struct {
float d;
float m;
uint8_t qs[QK128 / 2]; // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(float) * 2 + QK128 / 2, "wrong q4_2 block size/padding");
// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
assert(k % QK == 0);
@ -954,6 +965,267 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
#endif
}
static void quantize_row_q4_2_reference(const float * restrict x, void * restrict vy, int k) {
assert(k % QK128 == 0);
const int nb = k / QK128;
block_q4_2 * restrict y = vy;
uint8_t pp[QK128/2];
for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
for (int l = 0; l < QK128; l++) {
const float v = x[i*QK128 + l];
if (v < min) min = v;
if (v > max) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
y[i].m = min;
for (int l = 0; l < QK128; l += 2) {
const float v0 = (x[i*QK128 + l + 0] - min)*id;
const float v1 = (x[i*QK128 + l + 1] - min)*id;
const uint8_t vi0 = roundf(v0);
const uint8_t vi1 = roundf(v1);
assert(vi0 < 16);
assert(vi1 < 16);
pp[l/2] = vi0 | (vi1 << 4);
}
memcpy(y[i].qs, pp, sizeof(pp));
}
}
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
assert(k % QK == 0);
const int nb = k / QK;
block_q4_2 * restrict y = vy;
#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
__m256 v4 = _mm256_loadu_ps( x + 32 );
__m256 v5 = _mm256_loadu_ps( x + 40 );
__m256 v6 = _mm256_loadu_ps( x + 48 );
__m256 v7 = _mm256_loadu_ps( x + 56 );
__m256 v8 = _mm256_loadu_ps( x + 64 );
__m256 v9 = _mm256_loadu_ps( x + 72 );
__m256 v10 = _mm256_loadu_ps( x + 80 );
__m256 v11 = _mm256_loadu_ps( x + 88 );
__m256 v12 = _mm256_loadu_ps( x + 96 );
__m256 v13 = _mm256_loadu_ps( x + 104 );
__m256 v14 = _mm256_loadu_ps( x + 112 );
__m256 v15 = _mm256_loadu_ps( x + 120 );
x += 128;
// Compute max for the block
__m256 vmax;
vmax = _mm256_max_ps( v0, v1 );
vmax = _mm256_max_ps( vmax, v2 );
vmax = _mm256_max_ps( vmax, v3 );
vmax = _mm256_max_ps( vmax, v4 );
vmax = _mm256_max_ps( vmax, v5 );
vmax = _mm256_max_ps( vmax, v6 );
vmax = _mm256_max_ps( vmax, v7 );
vmax = _mm256_max_ps( vmax, v8 );
vmax = _mm256_max_ps( vmax, v9 );
vmax = _mm256_max_ps( vmax, v10 );
vmax = _mm256_max_ps( vmax, v11 );
vmax = _mm256_max_ps( vmax, v12 );
vmax = _mm256_max_ps( vmax, v13 );
vmax = _mm256_max_ps( vmax, v14 );
vmax = _mm256_max_ps( vmax, v15 );
__m128 max16 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
max16 = _mm_max_ps( max16, _mm_movehl_ps( max16, max16 ) );
max16 = _mm_max_ss( max16, _mm_movehdup_ps( max16 ) );
const float maxScalar = _mm_cvtss_f32( max16 );
// Compute min for the block
__m256 vmin;
vmin = _mm256_min_ps( v0, v1 );
vmin = _mm256_min_ps( vmin, v2 );
vmin = _mm256_min_ps( vmin, v3 );
vmin = _mm256_min_ps( vmin, v4 );
vmin = _mm256_min_ps( vmin, v5 );
vmin = _mm256_min_ps( vmin, v6 );
vmin = _mm256_min_ps( vmin, v7 );
vmin = _mm256_min_ps( vmin, v8 );
vmin = _mm256_min_ps( vmin, v9 );
vmin = _mm256_min_ps( vmin, v10 );
vmin = _mm256_min_ps( vmin, v11 );
vmin = _mm256_min_ps( vmin, v12 );
vmin = _mm256_min_ps( vmin, v13 );
vmin = _mm256_min_ps( vmin, v14 );
vmin = _mm256_min_ps( vmin, v15 );
__m128 min16 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
min16 = _mm_min_ps( min16, _mm_movehl_ps( min16, min16 ) );
min16 = _mm_min_ss( min16, _mm_movehdup_ps( min16 ) );
const float minScalar = _mm_cvtss_f32( min16 );
// Quantize these floats
const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].m = minScalar;
y[i].d = d;
// x = (x-min)*id
const __m256 mul = _mm256_set1_ps( id );
const __m256 off = _mm256_set1_ps( minScalar );
v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
v4 = _mm256_mul_ps( _mm256_sub_ps( v4, off ), mul );
v5 = _mm256_mul_ps( _mm256_sub_ps( v5, off ), mul );
v6 = _mm256_mul_ps( _mm256_sub_ps( v6, off ), mul );
v7 = _mm256_mul_ps( _mm256_sub_ps( v7, off ), mul );
v8 = _mm256_mul_ps( _mm256_sub_ps( v8, off ), mul );
v9 = _mm256_mul_ps( _mm256_sub_ps( v9, off ), mul );
v10 = _mm256_mul_ps( _mm256_sub_ps( v10, off ), mul );
v11 = _mm256_mul_ps( _mm256_sub_ps( v11, off ), mul );
v12 = _mm256_mul_ps( _mm256_sub_ps( v12, off ), mul );
v13 = _mm256_mul_ps( _mm256_sub_ps( v13, off ), mul );
v14 = _mm256_mul_ps( _mm256_sub_ps( v14, off ), mul );
v15 = _mm256_mul_ps( _mm256_sub_ps( v15, off ), mul );
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
v4 = _mm256_round_ps( v4, _MM_ROUND_NEAREST );
v5 = _mm256_round_ps( v5, _MM_ROUND_NEAREST );
v6 = _mm256_round_ps( v6, _MM_ROUND_NEAREST );
v7 = _mm256_round_ps( v7, _MM_ROUND_NEAREST );
v8 = _mm256_round_ps( v8, _MM_ROUND_NEAREST );
v9 = _mm256_round_ps( v9, _MM_ROUND_NEAREST );
v10 = _mm256_round_ps( v10, _MM_ROUND_NEAREST );
v11 = _mm256_round_ps( v11, _MM_ROUND_NEAREST );
v12 = _mm256_round_ps( v12, _MM_ROUND_NEAREST );
v13 = _mm256_round_ps( v13, _MM_ROUND_NEAREST );
v14 = _mm256_round_ps( v14, _MM_ROUND_NEAREST );
v15 = _mm256_round_ps( v15, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
__m256i i4 = _mm256_cvtps_epi32( v4 );
__m256i i5 = _mm256_cvtps_epi32( v5 );
__m256i i6 = _mm256_cvtps_epi32( v6 );
__m256i i7 = _mm256_cvtps_epi32( v7 );
__m256i i8 = _mm256_cvtps_epi32( v8 );
__m256i i9 = _mm256_cvtps_epi32( v9 );
__m256i i10 = _mm256_cvtps_epi32( v10 );
__m256i i11 = _mm256_cvtps_epi32( v11 );
__m256i i12 = _mm256_cvtps_epi32( v12 );
__m256i i13 = _mm256_cvtps_epi32( v13 );
__m256i i14 = _mm256_cvtps_epi32( v14 );
__m256i i15 = _mm256_cvtps_epi32( v15 );
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
i4 = _mm256_packs_epi32( i4, i5 );
i6 = _mm256_packs_epi32( i6, i7 );
i8 = _mm256_packs_epi32( i8, i9 );
i10 = _mm256_packs_epi32( i10, i11 );
i12 = _mm256_packs_epi32( i12, i13 );
i14 = _mm256_packs_epi32( i14, i15 );
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
i4 = _mm256_packs_epi16( i4, i6 );
i8 = _mm256_packs_epi16( i8, i10 );
i12 = _mm256_packs_epi16( i12, i14 );
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following instruction is fixing the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
i4 = _mm256_permutevar8x32_epi32( i4, perm );
i8 = _mm256_permutevar8x32_epi32( i8, perm );
i12 = _mm256_permutevar8x32_epi32( i12, perm );
// Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( i0 );
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
res = packNibbles( i4 );
_mm_storeu_si128( (( __m128i* )y[i].qs) + 1, res );
res = packNibbles( i8 );
_mm_storeu_si128( (( __m128i* )y[i].qs) + 2, res );
res = packNibbles( i12 );
_mm_storeu_si128( (( __m128i* )y[i].qs) + 3, res );
}
#elif __ARM_NEON
for (int i = 0; i < nb; i++) {
float32x4_t srcv[32];
float32x4_t minv[32];
float32x4_t maxv[32];
for (int l = 0; l < 32; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
for (int l = 0; l < 16; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
for (int l = 0; l < 8; l++) minv[4*l] = vminq_f32(srcv[4*l], srcv[4*l + 2]);
for (int l = 0; l < 4; l++) minv[8*l] = vminq_f32(srcv[8*l], srcv[8*l + 4]);
for (int l = 0; l < 2; l++) minv[16*l] = vminq_f32(minv[16*l], minv[16*l + 8]);
for (int l = 0; l < 1; l++) minv[32*l] = vminq_f32(minv[32*l], minv[32*l + 16]);
for (int l = 0; l < 16; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
for (int l = 0; l < 8; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
for (int l = 0; l < 4; l++) maxv[8*l] = vmaxq_f32(srcv[8*l], srcv[8*l + 4]);
for (int l = 0; l < 2; l++) maxv[16*l] = vmaxq_f32(maxv[16*l], maxv[16*l + 8]);
for (int l = 0; l < 1; l++) maxv[32*l] = vmaxq_f32(maxv[32*l], maxv[32*l + 16]);
const float min = vminvq_f32(minv[0]);
const float max = vmaxvq_f32(maxv[0]);
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
y[i].m = min;
const float32x4_t minv0 = vdupq_n_f32(min);
for (int l = 0; l < 32; l++) {
const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
const int32x4_t vi = vcvtq_s32_f32(vf);
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
}
#else
// scalar
quantize_row_q4_2_reference(x, vy, k);
#endif
}
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
@ -1178,6 +1450,112 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
#endif
}
static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
assert(k % QK128 == 0);
const int nb = k / QK128;
const block_q4_2 * restrict x = vx;
#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK128; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
__m256i vx8 = bytesFromNibbles(pp+l/2);
// Convert to 16-bit int
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
// Convert to 32-bit int -> float 32
const __m256 vf[4] = {
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
};
// Scale, add m and store
for (int j = 0; j < 4; j++) {
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
_mm256_storeu_ps(y + i * QK128 + l + j*8, result);
}
}
}
#elif defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
const float32x4_t vd = vdupq_n_f32(x[i].d);
const float32x4_t vm = vdupq_n_f32(x[i].m);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK128; l += 16) {
// Load 16x4-bit integers into 8x8-bit integers
const uint8x8_t v8 = vld1_u8(pp + l/2);
// Expand 4-bit qs to 8-bit bytes
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
const uint8x8_t v1 = vshr_n_u8(v8, 4);
// Interleave and combine
const uint8x8_t vx_0 = vzip1_u8(v0, v1);
const uint8x8_t vx_1 = vzip2_u8(v0, v1);
const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
// convert to 2x uint16x8_t
const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
// convert to 4x float32x4_t
const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
// multiply by d and add m
const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
// Store
vst1q_f32(y + i*QK128 + l + 0, r0);
vst1q_f32(y + i*QK128 + l + 4, r1);
vst1q_f32(y + i*QK128 + l + 8, r2);
vst1q_f32(y + i*QK128 + l + 12, r3);
}
}
#else
for (int i = 0; i < nb; i++) {
const float d = x[i].d;
const float m = x[i].m;
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK128; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
y[i*QK128 + l + 0] = v0;
y[i*QK128 + l + 1] = v1;
assert(!isnan(y[i*QK128 + l + 0]));
assert(!isnan(y[i*QK128 + l + 1]));
}
}
#endif
}
//
// simd mappings
//
@ -2318,6 +2696,164 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
*s = sumf;
}
static void ggml_vec_dot_q4_2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK128;
const block_q4_2 * restrict x = vx;
const block_q4_2 * restrict y = vy;
float sumf = 0.0;
#if defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Accumulator for constant offsets
float acc_offset = 0.0f;
// Main loop
for (int i = 0; i < nb; ++i) {
const float * d0 = &x[i].d;
const float * d1 = &y[i].d;
const float * m0 = &x[i].m;
const float * m1 = &y[i].m;
const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 );
const __m256 m0v = _mm256_broadcast_ss( m0 );
const __m256 m1v = _mm256_broadcast_ss( m1 );
// Compute combined scale for the block
const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
// Compute cross scales for the block
const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
const uint8_t * restrict x_pp = x[i].qs;
const uint8_t * restrict y_pp = y[i].qs;
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
for (int l = 0; l < QK128; l += 32) {
__m256i bx = bytesFromNibbles( x_pp + l/2);
__m256i by = bytesFromNibbles( y_pp + l/2);
// Now we have a vector with bytes in [ 0 .. 15 ] interval.
// Sign-extend first 16 signed bytes into int16_t
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
// Compute products of int16_t integers, add pairwise
__m256i i32 = _mm256_madd_epi16( x16, y16 );
// Sign-extend last 16 signed bytes into int16_t vectors
__m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
__m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
// Accumulate products of int16_t integers
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
// compute sums of unsigned bytes in bx, by in blocks of 8.
// This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
// which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
// so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
__m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
__m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
__m256 sums = _mm256_cvtepi32_ps( sumsi );
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( i32 );
// Apply the scale, and accumulate
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
acc = _mm256_fmadd_ps( scale_01, p, acc );
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
}
// acc_offset += m0*m1 (for each entry in the block)
acc_offset += (*m0)*(*m1);
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res ) + acc_offset * QK128;
#elif defined(__ARM_NEON)
float sum00 = 0.0f;
float sum01 = 0.0f;
float sum10 = 0.0f;
float sum11 = 0.0f;
for (int i = 0; i < nb; ++i) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict y0 = &y[i + 0];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8_t * restrict x_pp = x0->qs;
const uint8_t * restrict y_pp = x0->qs;
for (int l = 0; l < QK128; l += 32) {
const uint8x16_t v0_0 = vld1q_u8(x_pp + l/2);
const uint8x16_t v1_0 = vld1q_u8(y_pp + l/2);
// and with 0xf
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
// dot product into uint16x8_t
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
sum00 += x0->m*y0->m;
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
}
}
sumf = QK128*sum00 + sum01 + sum10 + sum11;
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float d1 = y[i].d;
const float m0 = x[i].m;
const float m1 = y[i].m;
const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = y[i].qs;
for (int j = 0; j < QK128/2; j++) {
const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j];
const float f0 = d0*(v0 & 0xf) + m0;
const float f1 = d0*(v0 >> 4) + m0;
const float f2 = d1*(v1 & 0xf) + m1;
const float f3 = d1*(v1 >> 4) + m1;
sumf += f0*f2 + f1*f3;
}
}
#endif
*s = sumf;
}
// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@ -2564,22 +3100,22 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F16] = 1,
[GGML_TYPE_Q4_0] = QK,
[GGML_TYPE_Q4_1] = QK,
[GGML_TYPE_Q4_2] = QK128,
[GGML_TYPE_I8] = 1,
[GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1,
};
static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float),
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
[GGML_TYPE_Q4_2] = sizeof(block_q4_2),
[GGML_TYPE_I8] = sizeof(int8_t),
[GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t),
};
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
@ -3247,6 +3783,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_2:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@ -3307,6 +3847,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_2:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@ -3361,6 +3905,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_2:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3405,6 +3953,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_2:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3447,6 +3999,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_2:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -3491,6 +4047,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_Q4_2:
{
GGML_ASSERT(false);
} break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@ -5229,6 +5789,7 @@ static void ggml_compute_forward_dup(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5310,6 +5871,7 @@ static void ggml_compute_forward_add(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5362,6 +5924,7 @@ static void ggml_compute_forward_sub(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5414,6 +5977,7 @@ static void ggml_compute_forward_mul(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5466,6 +6030,7 @@ static void ggml_compute_forward_div(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5514,6 +6079,7 @@ static void ggml_compute_forward_sqr(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5562,6 +6128,7 @@ static void ggml_compute_forward_sqrt(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5620,6 +6187,7 @@ static void ggml_compute_forward_sum(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5697,6 +6265,7 @@ static void ggml_compute_forward_mean(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5761,6 +6330,7 @@ static void ggml_compute_forward_repeat(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5809,6 +6379,7 @@ static void ggml_compute_forward_abs(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5857,6 +6428,7 @@ static void ggml_compute_forward_sgn(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5905,6 +6477,7 @@ static void ggml_compute_forward_neg(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -5953,6 +6526,7 @@ static void ggml_compute_forward_step(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6001,6 +6575,7 @@ static void ggml_compute_forward_relu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6066,6 +6641,7 @@ static void ggml_compute_forward_gelu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6133,6 +6709,7 @@ static void ggml_compute_forward_silu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6219,6 +6796,7 @@ static void ggml_compute_forward_norm(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6299,6 +6877,7 @@ static void ggml_compute_forward_rms_norm(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -6708,6 +7287,12 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
.vec_dot_q = ggml_vec_dot_q4_1,
},
[GGML_TYPE_Q4_2] = {
.dequantize_row_q = dequantize_row_q4_2,
.quantize_row_q = quantize_row_q4_2,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
.vec_dot_q = ggml_vec_dot_q4_2,
},
};
// For internal test use
@ -6915,6 +7500,10 @@ static void ggml_compute_forward_mul_mat(
{
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_2:
{
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
@ -6953,6 +7542,26 @@ static void ggml_compute_forward_mul_mat(
printf("\n");
exit(0);
}
} else if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_2) {
static int first = 8;
printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
if (first) {
--first;
} else {
for (int k = 0; k < dst->ne[1]; ++k) {
for (int j = 0; j < dst->ne[0]/16; ++j) {
for (int i = 0; i < 16; ++i) {
printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
}
printf("\n");
}
printf("\n");
}
printf("\n");
exit(0);
}
} else {
printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
@ -7010,6 +7619,7 @@ static void ggml_compute_forward_scale(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7178,6 +7788,10 @@ static void ggml_compute_forward_get_rows(
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_2:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
@ -7264,6 +7878,7 @@ static void ggml_compute_forward_diag_mask_inf(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7358,6 +7973,7 @@ static void ggml_compute_forward_soft_max(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7533,6 +8149,7 @@ static void ggml_compute_forward_rope(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -7801,6 +8418,7 @@ static void ggml_compute_forward_conv_1d_1s(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -8069,6 +8687,7 @@ static void ggml_compute_forward_conv_1d_2s(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -8554,6 +9173,7 @@ static void ggml_compute_forward_flash_attn(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@ -8765,6 +9385,7 @@ static void ggml_compute_forward_flash_ff(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32: