Compare commits

...
Sign in to create a new pull request.

8 commits

Author SHA1 Message Date
Georgi Gerganov
102cd98074 ggml : Q4_3c using 2x "Full range" approach 2023-04-23 14:56:44 +03:00
Georgi Gerganov
71e6ae3779 ggml : continue from #729 (wip) 2023-04-22 18:49:07 +03:00
Håkon H. Hitland
bd166f7ffc Fix type error in quantize_row_q4_1 for Arm NEON 2023-04-22 17:55:03 +03:00
Håkon H. Hitland
4282f9b0f3 Update quantize_row_q4_1 for PowerPC
Untested
2023-04-22 17:55:03 +03:00
Håkon H. Hitland
93c95fcc1b Update quantize_row_q4_0 for Arm NEON
Untested
2023-04-22 17:55:01 +03:00
Håkon H. Hitland
b7e704658e Update quantize_row_q4_0 for WASM
Untested
2023-04-22 17:54:19 +03:00
Håkon H. Hitland
5d5f2b2efa Update quantize_row_q4_0 for AVX/AVX2 2023-04-22 17:54:19 +03:00
Håkon H. Hitland
3698f79e6a Use full range for q4_0 quantization
By keeping the sign of the highest magnitude, we can make sure the
highest value maps to -8, which is currently unused.
This is a bit of a freebie since it is fully backwards compatible with
the current format.

quantize-stats output:
before(7B):
q4_0                                              : mse 0.00000492, maxerr 0.14257812
after(7B):
q4_0                                              : mse 0.00000386, maxerr 0.18200684

(Most layers have reduced maxerr under this rule, but the total max
error is indeed slightly higher)
2023-04-22 17:54:05 +03:00
2 changed files with 300 additions and 268 deletions

View file

@ -31,8 +31,8 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
#define QK4_3 16 #define QK4_3 16
typedef struct { typedef struct {
__half d; // delta __half d0; // delta
__half m; // min __half d1; // delta
uint8_t qs[QK4_3 / 2]; // nibbles / quants uint8_t qs[QK4_3 / 2]; // nibbles / quants
} block_q4_3; } block_q4_3;
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
@ -112,22 +112,32 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
const int i = blockIdx.x; const int i = blockIdx.x;
const float d = x[i].d; const float d0 = x[i].d0;
const float m = x[i].m; const float d1 = x[i].d1;
const uint8_t * pp = x[i].qs; const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_3; l += 2) { for (int l = 0; l < QK4_3/2; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi0 = pp[l/2];
const uint8_t vi1 = pp[l/2 + QK4_3/4];
const int8_t vi0 = vi & 0xf; const int8_t vi0_0 = vi0 & 0xf;
const int8_t vi1 = vi >> 4; const int8_t vi0_1 = vi0 >> 4;
const float v0 = vi0*d + m; const int8_t vi1_0 = vi1 & 0xf;
const float v1 = vi1*d + m; const int8_t vi1_1 = vi1 >> 4;
y[i*QK4_3 + l + 0] = v0; const float v0_0 = (vi0_0 - 8)*d0;
y[i*QK4_3 + l + 1] = v1; const float v0_1 = (vi0_1 - 8)*d0;
const float v1_0 = (vi1_0 - 8)*d1;
const float v1_1 = (vi1_1 - 8)*d1;
y[i*QK4_3 + l + 0] = v0_0;
y[i*QK4_3 + l + 1] = v0_1;
y[i*QK4_3 + l + 0 + QK4_3/2] = v1_0;
y[i*QK4_3 + l + 1 + QK4_3/2] = v1_1;
} }
} }

534
ggml.c
View file

@ -655,8 +655,8 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
#define QK4_3 16 #define QK4_3 16
typedef struct { typedef struct {
ggml_fp16_t d; // delta ggml_fp16_t d0; // delta
ggml_fp16_t m; // min ggml_fp16_t d1; // min
uint8_t qs[QK4_3 / 2]; // nibbles / quants uint8_t qs[QK4_3 / 2]; // nibbles / quants
} block_q4_3; } block_q4_3;
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
@ -680,13 +680,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f;
for (int l = 0; l < QK4_0; l++) { for (int l = 0; l < QK4_0; l++) {
const float v = x[i*QK4_0 + l]; const float v = x[i*QK4_0 + l];
amax = MAX(amax, fabsf(v)); if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
} }
const float d = amax / ((1 << 3) - 1); const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
y[i].d = d; y[i].d = d;
@ -695,8 +699,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
const float v0 = x[i*QK4_0 + l + 0]*id; const float v0 = x[i*QK4_0 + l + 0]*id;
const float v1 = x[i*QK4_0 + l + 1]*id; const float v1 = x[i*QK4_0 + l + 1]*id;
const uint8_t vi0 = (int8_t)roundf(v0) + 8; const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
const uint8_t vi1 = (int8_t)roundf(v1) + 8; const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
assert(vi0 < 16); assert(vi0 < 16);
assert(vi1 < 16); assert(vi1 < 16);
@ -716,28 +720,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
#if defined(__POWER9_VECTOR__) #if defined(__POWER9_VECTOR__)
const vector float v85 = vec_splats(8.5f); const vector float v85 = vec_splats(8.5f);
const vector signed int v15 = vec_splats(15);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float max = 0.0f;
float min = 0.0f;
vector float srcv [8]; vector float srcv [8];
vector float asrcv[8]; vector float maxv[8];
vector float amaxv[8]; vector float minv[8];
for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l); for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]); //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
//for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]); //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
amaxv[0] = vec_max(amaxv[0], amaxv[2]); maxv[0] = vec_max(maxv[0], maxv[2]);
amaxv[4] = vec_max(amaxv[4], amaxv[6]); maxv[4] = vec_max(maxv[4], maxv[6]);
//for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]); //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
amaxv[0] = vec_max(amaxv[0], amaxv[4]); maxv[0] = vec_max(maxv[0], maxv[4]);
amax = MAX( for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)), //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3))); minv[0] = vec_min(minv[0], minv[2]);
minv[4] = vec_min(minv[4], minv[6]);
//for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
minv[0] = vec_min(minv[0], minv[4]);
const float d = amax / ((1 << 3) - 1);
max = MAX(
MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
min = MIN(
MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
const float magnitude = max >= fabsf(min) ? max : min;
const float d = magnitude / -8;
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
y[i].d = d; y[i].d = d;
@ -747,27 +765,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
for (int l = 0; l < 8; l++) { for (int l = 0; l < 8; l++) {
const vector float vf = vec_madd(srcv[l], vid, v85); const vector float vf = vec_madd(srcv[l], vid, v85);
const vector signed int vi = vec_signed(vf); const vector signed int vi = vec_signed(vf);
const vector signed int vc = vec_min(vi, v15);
pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4); pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4); pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
} }
} }
#elif __ARM_NEON #elif __ARM_NEON
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float32x4_t srcv [8]; float32x4_t srcv [8];
float32x4_t asrcv[8]; float32x4_t maxv[8];
float32x4_t amaxv[8]; float32x4_t minv[8];
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
const float amax = vmaxvq_f32(amaxv[0]); for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
const float d = amax / ((1 << 3) - 1); const float max = vmaxvq_f32(maxv[0]);
const float min = vminvq_f32(minv[0]);
const float magnitude = max >= fabsf(min) ? max : min;
const float d = magnitude / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
y[i].d = d; y[i].d = d;
@ -776,9 +800,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const float32x4_t v = vmulq_n_f32(srcv[l], id); const float32x4_t v = vmulq_n_f32(srcv[l], id);
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
const int32x4_t vi = vcvtq_s32_f32(vf); const int32x4_t vi = vcvtq_s32_f32(vf);
const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
} }
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
@ -790,22 +815,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
__m256 v3 = _mm256_loadu_ps( x + 24 ); __m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32; x += 32;
// Compute max(abs(e)) for the block // Compute max for the block
const __m256 signBit = _mm256_set1_ps( -0.0f ); __m256 max = _mm256_max_ps( v0, v1 );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); __m256 maxTmp = _mm256_max_ps( v2, v3 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); max = _mm256_max_ps( max, maxTmp );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 ); const float maxScalar = _mm_cvtss_f32( max4 );
// Compute min for the block
__m256 min = _mm256_min_ps( v0, v1 );
__m256 minTmp = _mm256_min_ps( v2, v3 );
min = _mm256_min_ps( min, minTmp );
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
const float minScalar = _mm_cvtss_f32( min4 );
// Quantize these floats // Quantize these floats
const float d = maxScalar / 7.0f; const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
const float d = magnitude / -8.0f;
y[i].d = d; y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
const __m256 mul = _mm256_set1_ps( id ); const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier // Apply the multiplier
@ -838,9 +872,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm ); i0 = _mm256_permutevar8x32_epi32( i0, perm );
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m256i off = _mm256_set1_epi8( 8 ); const __m256i off = _mm256_set1_epi8( 8 );
i0 = _mm256_add_epi8( i0, off ); i0 = _mm256_add_epi8( i0, off );
const __m256i maxNibble = _mm256_set1_epi8( 15 );
i0 = _mm256_min_epi8( i0, maxNibble );
// Compress the vector into 4 bit/value, and store // Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( i0 ); __m128i res = packNibbles( i0 );
@ -855,22 +891,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
__m256 v3 = _mm256_loadu_ps( x + 24 ); __m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32; x += 32;
// Compute max(abs(e)) for the block // Compute max for the block
const __m256 signBit = _mm256_set1_ps( -0.0f ); __m256 max = _mm256_max_ps( v0, v1 );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); __m256 maxTmp = _mm256_max_ps( v2, v3 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); max = _mm256_max_ps( max, maxTmp );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 ); const float maxScalar = _mm_cvtss_f32( max4 );
// Compute min for the block
__m256 min = _mm256_min_ps( v0, v1 );
__m256 minTmp = _mm256_min_ps( v2, v3 );
min = _mm256_min_ps( min, minTmp );
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
const float minScalar = _mm_cvtss_f32( min4 );
// Quantize these floats // Quantize these floats
const float d = maxScalar / 7.0f; const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
const float d = magnitude / -8.0f;
y[i].d = d; y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f; const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
const __m256 mul = _mm256_set1_ps( id ); const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier // Apply the multiplier
@ -911,10 +956,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
ni0 = _mm_packs_epi16( ni0, ni2 ); ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 ); ni4 = _mm_packs_epi16( ni4, ni6 );
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ] // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m128i off = _mm_set1_epi8( 8); const __m128i off = _mm_set1_epi8( 8 );
ni0 = _mm_add_epi8( ni0, off ); ni0 = _mm_add_epi8( ni0, off );
ni4 = _mm_add_epi8( ni4, off ); ni4 = _mm_add_epi8( ni4, off );
const __m128i maxNibble = _mm_set1_epi8( 15 );
ni0 = _mm_min_epi8( ni0, maxNibble );
ni4 = _mm_min_epi8( ni4, maxNibble );
// Compress the vector into 4 bit/value, and store // Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( ni0, ni4 ); __m128i res = packNibbles( ni0, ni4 );
@ -922,24 +970,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
} }
#elif defined(__wasm_simd128__) #elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float max = 0.0f;
float min = 0.0f;
v128_t srcv [8]; v128_t srcv [8];
v128_t asrcv[8]; v128_t maxv[8];
v128_t amaxv[8]; v128_t minv[8];
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
amax = MAX( for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
const float d = amax / ((1 << 3) - 1); max = MAX(
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
min = MIN(
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
const float magnitude = max >= fabsf(min) ? max : min;
const float d = magnitude / -8;
const float id = d ? 1.0/d : 0.0; const float id = d ? 1.0/d : 0.0;
y[i].d = d; y[i].d = d;
@ -948,9 +1004,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
} }
} }
#else #else
@ -1131,13 +1188,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f;
for (int l = 0; l < QK4_2; l++) { for (int l = 0; l < QK4_2; l++) {
const float v = x[i*QK4_2 + l]; const float v = x[i*QK4_2 + l];
amax = MAX(amax, fabsf(v)); if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
} }
const float d = amax / ((1 << 3) - 1); const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
@ -1147,8 +1208,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
const float v0 = x[i*QK4_2 + l + 0]*id; const float v0 = x[i*QK4_2 + l + 0]*id;
const float v1 = x[i*QK4_2 + l + 1]*id; const float v1 = x[i*QK4_2 + l + 1]*id;
const uint8_t vi0 = (uint8_t)(v0 + 8.5f); const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
const uint8_t vi1 = (uint8_t)(v1 + 8.5f); const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
assert(vi0 < 16); assert(vi0 < 16);
assert(vi1 < 16); assert(vi1 < 16);
@ -1158,93 +1219,12 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
} }
} }
static inline int nearest_int(float fval) {
assert(fval <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
}
static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
const float * restrict candidates, int8_t * restrict L) {
assert (nmin >= INT8_MIN);
assert (nmax <= INT8_MAX);
float amax = 0;
for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
if (!amax) { // all zero
for (int i=0; i<n; ++i) L[i] = 0;
return 1.f;
}
float best = 0, bestScale = 0;
for (int si=0; si<nCandidates; ++si) {
float iscale = candidates[si]/amax;
float sumlxP = 0; int suml2P = 0;
float sumlxM = 0; int suml2M = 0;
for (int i=0; i<n; ++i) {
int l = nearest_int(iscale*X[i]);
int lp = MAX(nmin, MIN(nmax, +l));
int lm = MAX(nmin, MIN(nmax, -l));
sumlxP += X[i]*lp; suml2P += lp*lp;
sumlxM += X[i]*lm; suml2M += lm*lm;
}
float sumlxP2 = sumlxP*sumlxP;
float sumlxM2 = sumlxM*sumlxM;
if (sumlxP2*suml2M > sumlxM2*suml2P) {
if (sumlxP2 > best*suml2P) {
best = sumlxP2/suml2P; bestScale = iscale;
}
} else {
if (sumlxM2 > best*suml2M) {
best = sumlxM2/suml2M; bestScale = -iscale;
}
}
}
float sumlx = 0; int suml2 = 0;
for (int i=0; i<n; ++i) {
int l = nearest_int(bestScale*X[i]);
l = MAX(nmin, MIN(nmax, l));
sumlx += X[i]*l; suml2 += l*l;
L[i] = l;
}
float scale = sumlx/suml2;
return scale;
}
static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
#define CANDIDATE_COUNT 8
static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
assert(k % QK4_2 == 0);
int8_t L[QK4_2];
const int nb = k / QK4_2;
for (int i = 0; i < nb; i++) {
float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
y[i].d = GGML_FP32_TO_FP16(scale);
for (int l = 0; l < QK4_2; l += 2) {
const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
assert(vi0 < 16);
assert(vi1 < 16);
y[i].qs[l/2] = vi0 | (vi1 << 4);
}
x += QK4_2;
}
}
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) { static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
assert(k % QK4_2 == 0); assert(k % QK4_2 == 0);
block_q4_2 * restrict y = vy; block_q4_2 * restrict y = vy;
//quantize_row_q4_2_reference(x, y, k); quantize_row_q4_2_reference(x, y, k);
// This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
quantize_row_q4_2_rmse(x, y, k);
} }
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) { static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@ -1252,32 +1232,50 @@ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * r
const int nb = k / QK4_3; const int nb = k / QK4_3;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float min = FLT_MAX; float amax0 = 0.0f;
float max = -FLT_MAX; float max0 = 0.0f;
float amax1 = 0.0f;
float max1 = 0.0f;
for (int l = 0; l < QK4_3; l++) { for (int l = 0; l < QK4_3/2; l++) {
const float v = x[i*QK4_3 + l]; const float v0 = x[i*QK4_3 + l];
if (v < min) min = v; const float v1 = x[i*QK4_3 + l + QK4_3/2];
if (v > max) max = v;
if (amax0 < fabsf(v0)) {
amax0 = fabsf(v0);
max0 = v0;
}
if (amax1 < fabsf(v1)) {
amax1 = fabsf(v1);
max1 = v1;
}
} }
const float d = (max - min) / ((1 << 4) - 1); const float d0 = max0 / -8;
const float id = d ? 1.0f/d : 0.0f; const float d1 = max1 / -8;
y[i].d = GGML_FP32_TO_FP16(d); const float id0 = d0 ? 1.0f/d0 : 0.0f;
y[i].m = GGML_FP32_TO_FP16(min); const float id1 = d1 ? 1.0f/d1 : 0.0f;
for (int l = 0; l < QK4_3; l += 2) { y[i].d0 = GGML_FP32_TO_FP16(d0);
const float v0 = (x[i*QK4_3 + l + 0] - min)*id; y[i].d1 = GGML_FP32_TO_FP16(d1);
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
const uint8_t vi0 = (int) (v0 + 0.5f); for (int l = 0; l < QK4_3/2; l += 2) {
const uint8_t vi1 = (int) (v1 + 0.5f); const float v0_0 = x[i*QK4_3 + l + 0]*id0;
const float v0_1 = x[i*QK4_3 + l + 1]*id0;
assert(vi0 < 16); const float v1_0 = x[i*QK4_3 + l + 0 + QK4_3/2]*id1;
assert(vi1 < 16); const float v1_1 = x[i*QK4_3 + l + 1 + QK4_3/2]*id1;
y[i].qs[l/2] = vi0 | (vi1 << 4); const uint8_t vi0_0 = MIN(15, (int8_t)roundf(v0_0) + 8);
const uint8_t vi0_1 = MIN(15, (int8_t)roundf(v0_1) + 8);
const uint8_t vi1_0 = MIN(15, (int8_t)roundf(v1_0) + 8);
const uint8_t vi1_1 = MIN(15, (int8_t)roundf(v1_1) + 8);
y[i].qs[l/2 ] = vi0_0 | (vi0_1 << 4);
y[i].qs[l/2 + QK4_3/4] = vi1_0 | (vi1_1 << 4);
} }
} }
} }
@ -1749,25 +1747,32 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
const block_q4_3 * restrict x = vx; const block_q4_3 * restrict x = vx;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d); const float d0 = GGML_FP16_TO_FP32(x[i].d0);
const float m = GGML_FP16_TO_FP32(x[i].m); const float d1 = GGML_FP16_TO_FP32(x[i].d1);
const uint8_t * restrict pp = x[i].qs; const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK4_3; l += 2) { for (int l = 0; l < QK4_3/2; l += 2) {
const uint8_t vi = pp[l/2]; const uint8_t vi0 = pp[l/2];
const uint8_t vi1 = pp[l/2 + QK4_3/4];
const int8_t vi0 = vi & 0xf; const int8_t vi0_0 = vi0 & 0xf;
const int8_t vi1 = vi >> 4; const int8_t vi0_1 = vi0 >> 4;
const float v0 = vi0*d + m; const int8_t vi1_0 = vi1 & 0xf;
const float v1 = vi1*d + m; const int8_t vi1_1 = vi1 >> 4;
y[i*QK4_3 + l + 0] = v0; const float v0_0 = (vi0_0 - 8)*d0;
y[i*QK4_3 + l + 1] = v1; const float v0_1 = (vi0_1 - 8)*d0;
assert(!isnan(y[i*QK4_3 + l + 0])); const float v1_0 = (vi1_0 - 8)*d1;
assert(!isnan(y[i*QK4_3 + l + 1])); const float v1_1 = (vi1_1 - 8)*d1;
y[i*QK4_3 + l + 0] = v0_0;
y[i*QK4_3 + l + 1] = v0_1;
y[i*QK4_3 + l + 0 + QK4_3/2] = v1_0;
y[i*QK4_3 + l + 1 + QK4_3/2] = v1_1;
} }
} }
} }
@ -1795,7 +1800,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_2] = { [GGML_TYPE_Q4_2] = {
.dequantize_row_q = dequantize_row_q4_2, .dequantize_row_q = dequantize_row_q4_2,
.quantize_row_q = quantize_row_q4_2, .quantize_row_q = quantize_row_q4_2,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
.quantize_row_q_dot = quantize_row_q8_0, .quantize_row_q_dot = quantize_row_q8_0,
.vec_dot_q = ggml_vec_dot_q4_2_q8_0, .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
}, },
@ -2876,17 +2881,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
assert(n % QK8_0 == 0); assert(n % QK8_0 == 0);
assert(nb % 2 == 0); assert(nb % 2 == 0);
assert(QK8_0 == 2*QK4_2); assert(QK8_0 == 2*QK4_3);
const block_q4_3 * restrict x = vx; const block_q4_3 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x2_t sumv0 = vdup_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x2_t sumv1 = vdup_n_f32(0.0f);
float32x2_t sumv2 = vdup_n_f32(0.0f);
float summs0 = 0.0f; float32x2_t sumv3 = vdup_n_f32(0.0f);
float summs1 = 0.0f;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
@ -2894,29 +2898,46 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y0 = &y[i + 0];
summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; const uint8x8_t v0_0 = vld1_u8(x0_0->qs);
summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; const uint8x8_t v0_1 = vld1_u8(x0_1->qs);
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
// 4-bit -> 8-bit // 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); const int8x8_t v0_0l = vreinterpret_s8_u8(vand_u8 (v0_0, vdup_n_u8(0xf)));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); const int8x8_t v0_0h = vreinterpret_s8_u8(vshr_n_u8(v0_0, 4));
const int8x8_t v0_1l = vreinterpret_s8_u8(vand_u8 (v0_1, vdup_n_u8(0xf)));
const int8x8_t v0_1h = vreinterpret_s8_u8(vshr_n_u8(v0_1, 4));
// sub 8
const int8x8_t v0_0ls = vsub_s8(v0_0l, vdup_n_s8(8));
const int8x8_t v0_0hs = vsub_s8(v0_0h, vdup_n_s8(8));
const int8x8_t v0_1ls = vsub_s8(v0_1l, vdup_n_s8(8));
const int8x8_t v0_1hs = vsub_s8(v0_1h, vdup_n_s8(8));
// interleave // interleave
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); const int8x8_t v0_0lz = vzip1_s8(v0_0ls, v0_0hs);
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); const int8x8_t v0_0hz = vzip2_s8(v0_0ls, v0_0hs);
const int8x8_t v0_1lz = vzip1_s8(v0_1ls, v0_1hs);
const int8x8_t v0_1hz = vzip2_s8(v0_1ls, v0_1hs);
// load y // load y
const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x8_t v1_0l = vld1_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x8_t v1_0h = vld1_s8(y0->qs + 8);
const int8x8_t v1_1l = vld1_s8(y0->qs + 16);
const int8x8_t v1_1h = vld1_s8(y0->qs + 24);
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); const float x0_0d = GGML_FP16_TO_FP32(x0_0->d0);
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); const float x0_1d = GGML_FP16_TO_FP32(x0_0->d1);
const float x1_0d = GGML_FP16_TO_FP32(x0_1->d0);
const float x1_1d = GGML_FP16_TO_FP32(x0_1->d1);
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
sumv0 = vmla_n_f32(sumv0, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
sumv1 = vmla_n_f32(sumv1, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
sumv2 = vmla_n_f32(sumv2, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_1lz, v1_1l)), x1_0d*y0->d);
sumv3 = vmla_n_f32(sumv3, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_1hz, v1_1h)), x1_1d*y0->d);
#else #else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
@ -2931,77 +2952,79 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
#endif #endif
} }
*s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; *s = vaddv_f32(vadd_f32(vadd_f32(sumv0, sumv1), vadd_f32(sumv2, sumv3)));
#elif defined(__AVX2__) #elif defined(__AVX2__)
GGML_ASSERT(false); // TODO
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); //__m256 acc = _mm256_setzero_ps();
// Main loop //// Main loop
for (int i = 0; i < nb; i++) { //for (int i = 0; i < nb; i++) {
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); // const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); // const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 dx = _mm256_set_m128(d1, d0); // const __m256 dx = _mm256_set_m128(d1, d0);
const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m)); // const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m)); // const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
const __m256 mx = _mm256_set_m128(m1, m0); // const __m256 mx = _mm256_set_m128(m1, m0);
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); // const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); // const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
const __m256i bx = _mm256_set_m128i(bx1, bx0); // const __m256i bx = _mm256_set_m128i(bx1, bx0);
const __m256 dy = _mm256_broadcast_ss(&y[i].d); // const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); // const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by); // const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
const __m256 syf = sum_i16_pairs_float(syi); // const __m256 syf = sum_i16_pairs_float(syi);
const __m256 q = mul_sum_i8_pairs_float(bx, by); // const __m256 q = mul_sum_i8_pairs_float(bx, by);
const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf)); // const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
acc = _mm256_fmadd_ps(sxy, dy, acc); // acc = _mm256_fmadd_ps(sxy, dy, acc);
} //}
*s = hsum_float_8(acc); //*s = hsum_float_8(acc);
#else #else
// scalar GGML_ASSERT(false); // TODO
float sumf = 0.0; //// scalar
for (int i = 0; i < nb; i++) { //float sumf = 0.0;
const uint8_t * restrict x0 = x[2*i + 0].qs; //for (int i = 0; i < nb; i++) {
const uint8_t * restrict x1 = x[2*i + 1].qs; // const uint8_t * restrict x0 = x[2*i + 0].qs;
const int8_t * restrict y0 = y[i].qs; // const uint8_t * restrict x1 = x[2*i + 1].qs;
// const int8_t * restrict y0 = y[i].qs;
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); // const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); // const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); // const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); // const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
int sxy_0 = 0; // int sxy_0 = 0;
int sxy_1 = 0; // int sxy_1 = 0;
for (int j = 0; j < QK8_0/4; j++) { // for (int j = 0; j < QK8_0/4; j++) {
const uint8_t v0 = x0[j]; // const uint8_t v0 = x0[j];
const uint8_t v1 = x1[j]; // const uint8_t v1 = x1[j];
const int x0_0 = v0 & 0xf; // const int x0_0 = v0 & 0xf;
const int x1_0 = v0 >> 4; // const int x1_0 = v0 >> 4;
const int x0_1 = v1 & 0xf; // const int x0_1 = v1 & 0xf;
const int x1_1 = v1 >> 4; // const int x1_1 = v1 >> 4;
const int y0_0 = y0[2*j + 0]; // const int y0_0 = y0[2*j + 0];
const int y1_0 = y0[2*j + 1]; // const int y1_0 = y0[2*j + 1];
const int y0_1 = y0[2*(j + QK8_0/4) + 0]; // const int y0_1 = y0[2*(j + QK8_0/4) + 0];
const int y1_1 = y0[2*(j + QK8_0/4) + 1]; // const int y1_1 = y0[2*(j + QK8_0/4) + 1];
sxy_0 += x0_0*y0_0 + x1_0*y1_0; // sxy_0 += x0_0*y0_0 + x1_0*y1_0;
sxy_1 += x0_1*y0_1 + x1_1*y1_1; // sxy_1 += x0_1*y0_1 + x1_1*y1_1;
} // }
sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; // sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
} //}
*s = sumf; //*s = sumf;
#endif #endif
} }
@ -12127,8 +12150,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
for (int j = 0; j < n; j += k) { for (int j = 0; j < n; j += k) {
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2; block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
//quantize_row_q4_2_reference(src + j, y, k); quantize_row_q4_2_reference(src + j, y, k);
quantize_row_q4_2_rmse(src + j, y, k);
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
for (int l = 0; l < QK4_2; l += 2) { for (int l = 0; l < QK4_2; l += 2) {