ggml : use 8-bit precision for Q4_1 intermediate results (ARM)

This commit is contained in:
Georgi Gerganov 2023-04-18 22:12:19 +03:00
parent 7faa7460f0
commit 7840f6637c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

662
ggml.c
View file

@ -514,6 +514,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
(uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
}
inline static int16_t vaddvq_s8(int8x16_t v) {
return
(int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
(int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
(int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
(int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
(int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
(int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
(int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
(int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
}
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@ -1420,8 +1432,8 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
#endif
}
static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = {
@ -1435,8 +1447,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.dequantize_row_q = dequantize_row_q4_1,
.quantize_row_q = quantize_row_q4_1,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
.quantize_row_q_dot = quantize_row_q4_1,
.vec_dot_q = ggml_vec_dot_q4_1,
.quantize_row_q_dot = quantize_row_q8_0,
.vec_dot_q = ggml_vec_dot_q4_1_q8_0,
},
// TODO: GGML_TYPE_Q8_0
};
@ -2225,535 +2237,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
*s = sumf;
}
static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK4_0;
assert(n % QK4_0 == 0);
assert(nb % 2 == 0);
const block_q4_0 * restrict x = vx;
const block_q4_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON)
float sum0 = 0.0f;
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict y0 = &y[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q4_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
#endif
}
sumf = sum0 + sum1;
#elif defined(__AVX512F__)
// Initialize accumulator with zeros
__m512 acc0 = _mm512_setzero_ps();
__m512 acc1 = _mm512_setzero_ps();
const int superblock_size = 16;
const int superblock_count = nb / superblock_size;
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
int i = superblock_ix * superblock_size;
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
}
// Remainders
for (int i = superblock_count * superblock_size; i < nb; i += 2) {
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
}
// Horizontal sum of all lanes of the accumulator
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
/* Prepare the constants we will need during execution */
const __m256i lowMask = _mm256_set1_epi8( 0xF );
const __m256i offset_8 = _mm256_set1_epi16( 8 );
#define UNROLL_COUNT 8
// make sure we only unroll multiples of the block count
assert(nb % UNROLL_COUNT == 0);
// Main loop
for (int i = 0; i < nb; i+=UNROLL_COUNT) {
// This loop will be unrolled by the compiler
for (int u=0;u<UNROLL_COUNT;u++) {
/* Compute combined scale for the block */
const __m256 scale = _mm256_mul_ps(
_mm256_broadcast_ss( &x[i+u].d ),
_mm256_broadcast_ss( &y[i+u].d ) );
/* get input from x
Input: 32 Nibbles (16 bytes) at *x[i+u]
Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
/* Load 16 bytes from memory */
const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
/* Expand bytes into uint16_t values */
const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
/* Unpack values into individual bytes */
__m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
/* get input from y
Input: 32 Nibbles (16 bytes) at *y[i+u]
Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
/* Load 16 bytes from memory */
const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
/* Expand bytes into uint16_t values */
const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
/* Unpack values into individual bytes */
const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
/* Compute products of int16_t integers, add pairwise, store as int32_t */
__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
/* Convert to vectore of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps( xy_q );
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( scale, q, acc );
}
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
__m128i by = bytesFromNibbles( y[i].qs + 8*j );
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m128i off = _mm_set1_epi8( 8 );
bx = _mm_sub_epi8( bx, off );
by = _mm_sub_epi8( by, off );
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
const __m128i ones = _mm_set1_epi16(1);
i32[j] = _mm_madd_epi16(ones, dot);
}
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
#elif defined(__wasm_simd128__)
// wasm simd
float sum0 = 0.0f;
float sum1 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict y0 = &y[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q4_0 * restrict y1 = &y[i + 1];
const v128_t m4b = wasm_u8x16_splat(0xf);
const v128_t s8b = wasm_i8x16_splat(0x8);
const v128_t v0_0 = wasm_v128_load(x0->qs);
const v128_t v0_1 = wasm_v128_load(y0->qs);
const v128_t v1_0 = wasm_v128_load(x1->qs);
const v128_t v1_1 = wasm_v128_load(y1->qs);
// 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
// sub 8
const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
// dot product into int16x8_t
const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
sum0 += x0->d * y0->d * (
wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
sum1 += x1->d * y1->d * (
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
}
sumf = sum0 + sum1;
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float d1 = y[i].d;
const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = y[i].qs;
int sumi = 0;
for (int j = 0; j < QK4_0/2; j++) {
const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j];
const int i0 = (v0 & 0xf) - 8;
const int i1 = (v0 >> 4) - 8;
const int i2 = (v1 & 0xf) - 8;
const int i3 = (v1 >> 4) - 8;
sumi += i0*i2 + i1*i3;
}
sumf += d0 * d1 * sumi;
}
#endif
*s = sumf;
}
static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK4_1;
const block_q4_1 * restrict x = vx;
const block_q4_1 * restrict y = vy;
float sumf = 0.0;
#if defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Accumulator for constant offsets
float acc_offset = 0.0f;
// Main loop
for (int i = 0; i < nb; ++i) {
const float * d0 = &x[i].d;
const float * d1 = &y[i].d;
const float * m0 = &x[i].m;
const float * m1 = &y[i].m;
const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 );
const __m256 m0v = _mm256_broadcast_ss( m0 );
const __m256 m1v = _mm256_broadcast_ss( m1 );
// Compute combined scale for the block
const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
// Compute cross scales for the block
const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
__m256i bx = bytesFromNibbles( x[i].qs );
__m256i by = bytesFromNibbles( y[i].qs );
// Now we have a vector with bytes in [ 0 .. 15 ] interval.
// Sign-extend first 16 signed bytes into int16_t
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
// Compute products of int16_t integers, add pairwise
__m256i i32 = _mm256_madd_epi16( x16, y16 );
// Sign-extend last 16 signed bytes into int16_t vectors
__m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
__m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
// Accumulate products of int16_t integers
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
// compute sums of unsigned bytes in bx, by in blocks of 8.
// This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
// which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
// so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
__m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
__m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
__m256 sums = _mm256_cvtepi32_ps( sumsi );
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( i32 );
// Apply the scale, and accumulate
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
acc = _mm256_fmadd_ps( scale_01, p, acc );
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
// acc_offset += m0*m1 (for each entry in the block)
acc_offset += (*m0)*(*m1);
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
#elif defined(__ARM_NEON)
float sum00 = 0.0f;
float sum01 = 0.0f;
float sum10 = 0.0f;
float sum11 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict y0 = &y[i + 0];
const block_q4_1 * restrict x1 = &x[i + 1];
const block_q4_1 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
// 4-bit -> 8-bit
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
sum00 += x0->m*y0->m;
sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
sum00 += x1->m*y1->m;
sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
sum11 += x0->d*y0->d*vaddvq_u32(p_0);
sum11 += x1->d*y1->d*vaddvq_u32(p_1);
#else
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
sum11 += x0->d*y0->d*vaddvq_u16(p_0);
sum11 += x1->d*y1->d*vaddvq_u16(p_1);
#endif
}
sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float d1 = y[i].d;
const float m0 = x[i].m;
const float m1 = y[i].m;
const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = y[i].qs;
for (int j = 0; j < QK4_1/2; j++) {
const uint8_t v0 = p0[j];
const uint8_t v1 = p1[j];
const float f0 = d0*(v0 & 0xf) + m0;
const float f1 = d0*(v0 >> 4) + m0;
const float f2 = d1*(v1 & 0xf) + m1;
const float f3 = d1*(v1 >> 4) + m1;
sumf += f0*f2 + f1*f3;
}
}
#endif
*s = sumf;
}
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK8_0;
@ -2957,6 +2440,121 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
*s = sumf;
}
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK8_0;
assert(n % QK8_0 == 0);
assert(nb % 2 == 0);
const block_q4_1 * restrict x = vx;
const block_q8_0 * restrict y = vy;
float sumf = 0.0;
// TODO: add AVX / WASM SIMD / etc
#if defined(__ARM_NEON)
float sum00 = 0.0f;
float sum01 = 0.0f;
float sum10 = 0.0f;
float sum11 = 0.0f;
for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
// interleave
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
// Note: cannot use vaddvq_s8 because it overflows for 8-bit values
// TODO: is there a better way to do this?
sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) +
vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs))));
sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) +
vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs))));
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
sum10 += (x0->d*y0->d)*vaddvq_s32(p_0);
sum11 += (x1->d*y1->d)*vaddvq_s32(p_1);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
sum10 += x0->d*y0->d*vaddvq_s16(p_0);
sum11 += x1->d*y1->d*vaddvq_s16(p_1);
#endif
}
sumf = sum00 + sum01 + sum10 + sum11;
#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = x[i].d;
const float m0 = x[i].m;
const float d1 = y[i].d;
const uint8_t * restrict p0 = x[i].qs;
const int8_t * restrict p1 = y[i].qs;
// TODO: this is very slow ..
for (int j = 0; j < QK8_0/2; j++) {
const uint8_t v0 = p0[j];
const float f0 = d0*(v0 & 0xf) + m0;
const float f1 = d0*(v0 >> 4) + m0;
const float f2 = d1*p1[2*j + 0];
const float f3 = d1*p1[2*j + 1];
sumf += f0*f2 + f1*f3;
}
}
#endif
*s = sumf;
}
// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {