ggml : x2 speed for WASM by optimizing SIMD

Xuan Son Nguyen 2025-01-27 15:04:44 +01:00
parent a5203b4465
commit 610b3ac3cd


@@ -747,7 +747,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
@@ -1037,7 +1037,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
@@ -1653,7 +1653,105 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
//===================================== Q8_K ==============================================
void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
#ifdef __wasm_simd128__
assert(k % QK_K == 0);
const int64_t nb = k / QK_K;
block_q8_K * restrict yc = y; // Cast to proper type
for (int i = 0; i < nb; i++) {
const float * x_block = x + i * QK_K;
v128_t amax_vec = wasm_f32x4_splat(0.0f);
v128_t max_vec = wasm_f32x4_splat(0.0f);
// Vectorized max abs value search
for (int j = 0; j < QK_K; j += 4) {
v128_t x_vec = wasm_v128_load(x_block + j);
v128_t abs_x = wasm_f32x4_abs(x_vec);
v128_t mask = wasm_f32x4_gt(abs_x, amax_vec);
amax_vec = wasm_v128_bitselect(abs_x, amax_vec, mask);
max_vec = wasm_v128_bitselect(x_vec, max_vec, mask);
}
// Manual unroll for lane extraction
float amax = wasm_f32x4_extract_lane(amax_vec, 0);
float max_val = wasm_f32x4_extract_lane(max_vec, 0);
#define UPDATE_MAX(lane) \
{ \
float a = wasm_f32x4_extract_lane(amax_vec, lane); \
if (a > amax) { \
amax = a; \
max_val = wasm_f32x4_extract_lane(max_vec, lane); \
} \
}
UPDATE_MAX(1)
UPDATE_MAX(2)
UPDATE_MAX(3)
#undef UPDATE_MAX
if (amax == 0.0f) {
yc[i].d = 0.0f;
const v128_t zero = wasm_i8x16_splat(0);
for (int j = 0; j < QK_K; j += 16) {
wasm_v128_store(yc[i].qs + j, zero);
}
memset(yc[i].bsums, 0, QK_K/16 * sizeof(int));
continue;
}
const float iscale = -127.0f / max_val;
const v128_t scale_vec = wasm_f32x4_splat(iscale);
// Process 16 elements per iteration
for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
// Load and quantize 16 floats
v128_t x0 = wasm_v128_load(x_block + j);
v128_t x1 = wasm_v128_load(x_block + j + 4);
v128_t x2 = wasm_v128_load(x_block + j + 8);
v128_t x3 = wasm_v128_load(x_block + j + 12);
v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
// Convert to i32 with saturation
v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
// Pack into 16 i8 values
v128_t i8 = wasm_i8x16_narrow_i16x8(
wasm_i16x8_narrow_i32x4(
wasm_i32x4_min(wasm_i32x4_max(i0, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)),
wasm_i32x4_min(wasm_i32x4_max(i1, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127))
),
wasm_i16x8_narrow_i32x4(
wasm_i32x4_min(wasm_i32x4_max(i2, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)),
wasm_i32x4_min(wasm_i32x4_max(i3, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127))
)
);
wasm_v128_store(yc[i].qs + j, i8);
// Calculate bsums using SIMD
v128_t sum16 = wasm_i16x8_add(
wasm_i16x8_extend_low_i8x16(i8),
wasm_i16x8_extend_high_i8x16(i8)
);
v128_t sum32 = wasm_i32x4_add(
wasm_i32x4_extend_low_i16x8(sum16),
wasm_i32x4_extend_high_i16x8(sum16)
);
sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
}
yc[i].d = 1.0f / iscale;
}
#else
quantize_row_q8_K_ref(x, y, k);
#endif
}
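For readers comparing against the reference path, the per-block math that the WASM block above vectorizes is roughly the following scalar sketch (an illustrative helper, not code from this commit; it assumes the usual block_q8_K layout of one float scale, QK_K int8 quants and QK_K/16 group sums):

#include <math.h>
#include <stdint.h>
#include <string.h>

// Scalar sketch: largest-magnitude search, iscale = -127/max, round-to-nearest
// quantization clamped to [-127, 127], and 16-wide partial sums (bsums).
static void q8_K_quantize_block_sketch(const float * x, int n /* QK_K */,
                                       float * d, int8_t * qs, int16_t * bsums) {
    float amax = 0.0f, max_val = 0.0f;
    for (int j = 0; j < n; ++j) {
        const float ax = fabsf(x[j]);
        if (ax > amax) { amax = ax; max_val = x[j]; }
    }
    if (amax == 0.0f) {                          // all-zero block
        *d = 0.0f;
        memset(qs, 0, n);
        memset(bsums, 0, (n/16) * sizeof(int16_t));
        return;
    }
    const float iscale = -127.0f / max_val;
    for (int j = 0; j < n; ++j) {
        int v = (int)nearbyintf(iscale * x[j]);
        if (v >  127) v =  127;
        if (v < -127) v = -127;
        qs[j] = (int8_t)v;
    }
    for (int g = 0; g < n/16; ++g) {             // per-group sums used by the K dot products
        int sum = 0;
        for (int l = 0; l < 16; ++l) sum += qs[16*g + l];
        bsums[g] = (int16_t)sum;
    }
    *d = 1.0f / iscale;
}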
//===================================== Dot products =================================
@@ -2011,6 +2109,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
const v128_t m4b = wasm_i8x16_splat(0x0F);
const v128_t s8b = wasm_i8x16_splat(0x8);
for (; ib + 1 < nb; ib += 2) {
const block_q4_0 * restrict x0 = &x[ib];
const block_q4_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib];
const block_q8_0 * restrict y1 = &y[ib + 1];
// Load and process x0
v128_t v0_0 = wasm_v128_load(x0->qs);
v128_t v0_0l = wasm_v128_and(v0_0, m4b);
v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
// Load y0 vectors
v128_t y0_l = wasm_v128_load(y0->qs);
v128_t y0_h = wasm_v128_load(y0->qs + 16);
// Extend to i16x8 and compute dot products
v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
v128_t dp0 = wasm_i32x4_add(
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx0l, dy0ll),
wasm_i32x4_dot_i16x8(dx0h, dy0lh)
),
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
)
);
// Load and process x1
v128_t v0_1 = wasm_v128_load(x1->qs);
v128_t v0_1l = wasm_v128_and(v0_1, m4b);
v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
// Load y1 vectors
v128_t y1_l = wasm_v128_load(y1->qs);
v128_t y1_h = wasm_v128_load(y1->qs + 16);
// Extend to i16x8 and compute dot products
v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
v128_t dp1 = wasm_i32x4_add(
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx1l, dy1ll),
wasm_i32x4_dot_i16x8(dx1h, dy1lh)
),
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
)
);
// Accumulate results with scaling
float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
}
sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
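As a sanity reference for the unrolled WASM loop above, each q4_0/q8_0 block pair contributes roughly the following (a scalar sketch with illustrative names, not code from this commit; it assumes the standard q4_0 layout where the 16 low nibbles pair with the first 16 q8 values and the high nibbles with the next 16):

#include <stdint.h>

// One q4_0 x q8_0 block: recenter each nibble by -8, multiply with the matching
// int8 value, then scale the integer sum by both block scales.
static float q4_0_q8_0_block_dot_sketch(const uint8_t * q4, float d4,
                                        const int8_t * q8, float d8) {
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int lo = (q4[j] & 0x0F) - 8;
        const int hi = (q4[j] >> 4)   - 8;
        sumi += lo * q8[j] + hi * q8[j + 16];
    }
    return d4 * d8 * (float)sumi;
}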
@@ -2696,10 +2882,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
-uint32_t qh;
+uint32_t qh_;
uint64_t tmp[4];
// TODO: check if unrolling this is better
@@ -2710,12 +2896,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
-memcpy(&qh, x0->qh, sizeof(qh));
-tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
-tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
-tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
-tmp[3] = table_b2b_1[(qh >> 24) ];
+memcpy(&qh_, x0->qh, sizeof(qh_));
+tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
+tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
+tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
+tmp[3] = table_b2b_1[(qh_ >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3057,12 +3243,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
float summs = 0.0f;
-uint32_t qh;
+uint32_t qh_;
uint64_t tmp[4];
// TODO: check if unrolling this is better
@@ -3075,12 +3261,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
-memcpy(&qh, x0->qh, sizeof(qh));
-tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
-tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
-tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
-tmp[3] = table_b2b_0[(qh >> 24) ];
+memcpy(&qh_, x0->qh, sizeof(qh_));
+tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
+tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
+tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
+tmp[3] = table_b2b_0[(qh_ >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
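The q5_0/q5_1 paths above keep using the table_b2b_* lookup tables to spread the 32 high bits from qh across bytes before combining them with the nibbles. As a hedged sketch of the general idea only (the real tables in ggml-quants.c also provide a complemented variant so the high bit can be applied by subtraction; this helper is illustrative, not the table definition):

#include <stdint.h>

// Expand the 8 bits of one qh byte into 8 bytes, each carrying its bit at
// position 4 (0x10), which is what a table_b2b-style lookup precomputes.
static uint64_t expand_bits_to_bytes(uint8_t b) {
    uint64_t out = 0;
    for (int i = 0; i < 8; ++i) {
        out |= ((uint64_t)(((b >> i) & 1u) << 4)) << (8 * i);
    }
    return out;
}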
@@ -3573,6 +3759,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
for (; ib < nb; ++ib) {
const block_q8_0 * restrict x0 = &x[ib];
const block_q8_0 * restrict y0 = &y[ib];
const v128_t x0_0 = wasm_v128_load(x0->qs);
const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
const v128_t y0_0 = wasm_v128_load(y0->qs);
const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
// Extend 8-bit to 16-bit
const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
// Compute dot products
const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
// Sum all dot products
const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
// Convert to float and accumulate
const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
}
sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
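The q8_0 path above widens to i16 and leans on wasm_i32x4_dot_i16x8; per block it computes the plain 32-element integer dot product scaled by both block scales, i.e. roughly this sketch (illustrative name, not code from this commit):

#include <stdint.h>

// One q8_0 x q8_0 block (QK8_0 == 32 in ggml): integer dot product, then scale.
static float q8_0_q8_0_block_dot_sketch(const int8_t * xq, float dx,
                                        const int8_t * yq, float dy) {
    int sumi = 0;
    for (int j = 0; j < 32; ++j) {
        sumi += (int)xq[j] * (int)yq[j];
    }
    return dx * dy * (float)sumi;
}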
@@ -4447,6 +4672,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
#elif defined __wasm_simd128__
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
// Vectorized summs calculation
v128_t summs_vec = wasm_i32x4_splat(0);
{
v128_t sc_vec = wasm_v128_load(sc);
v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
summs_vec = wasm_i32x4_add(
wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
wasm_i32x4_dot_i16x8(sc_high, bsums2)),
summs_vec
);
summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
}
int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
// Vectorized isum calculation
int32_t isum = 0;
const uint8_t * sc_ptr = sc;
const int k_iters = QK_K/128;
for (int k = 0; k < k_iters; ++k) {
v128_t isum_vec = wasm_i32x4_splat(0);
int shift = 0;
for (int j = 0; j < 4; ++j) {
const int d0 = (sc_ptr[0] & 0xF);
const int d1 = (sc_ptr[1] & 0xF);
sc_ptr += 2;
// Process first 16 elements
v128_t q2_0 = wasm_v128_load(q2);
v128_t q8_0 = wasm_v128_load(q8);
v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
// Process next 16 elements
v128_t q2_1 = wasm_v128_load(q2 + 16);
v128_t q8_1 = wasm_v128_load(q8 + 16);
v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
// Calculate dot products
v128_t p0 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q8_0),
wasm_i16x8_extend_low_i8x16(q2_bits_0)
);
v128_t p1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q8_0),
wasm_i16x8_extend_high_i8x16(q2_bits_0)
);
v128_t p2 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q8_1),
wasm_i16x8_extend_low_i8x16(q2_bits_1)
);
v128_t p3 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q8_1),
wasm_i16x8_extend_high_i8x16(q2_bits_1)
);
// Accumulate scaled results
v128_t scaled = wasm_i32x4_add(
wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
);
isum_vec = wasm_i32x4_add(isum_vec, scaled);
q8 += 32;
shift += 2;
}
q2 += 32;
// Horizontal sum of isum_vec
isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
isum += wasm_i32x4_extract_lane(isum_vec, 0);
}
const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf += dall * isum - dmin * summs;
}
*s = sumf;
#elif defined __riscv_v_intrinsic
float sumf = 0;
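For orientation, the q2_K block above implements in vector form roughly this scalar accumulation; the group/shift indexing mirrors the loops in the SIMD code, and the helper name is illustrative only (QK_K assumed to be 256):

#include <stdint.h>

// One q2_K x q8_K block: low scale nibbles weight the 2-bit dot products,
// high nibbles (the mins) weight the per-group q8 sums, combined as
// dall * isum - dmin * summs.
static float q2_K_q8_K_block_dot_sketch(const uint8_t * q2, const uint8_t * sc,
                                        const int8_t * q8, const int16_t * bsums,
                                        float dall, float dmin) {
    int summs = 0;
    for (int g = 0; g < 16; ++g) {
        summs += (sc[g] >> 4) * bsums[g];
    }
    int isum = 0;
    for (int k = 0; k < 2; ++k) {                 // two 128-value halves
        for (int j = 0; j < 4; ++j) {             // four 2-bit planes per byte
            for (int l = 0; l < 32; ++l) {
                const int q     = (q2[k*32 + l] >> (2*j)) & 3;
                const int scale = sc[k*8 + j*2 + l/16] & 0xF;
                isum += scale * q * q8[k*128 + j*32 + l];
            }
        }
    }
    return dall * isum - dmin * summs;
}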
@@ -5129,6 +5454,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
#elif defined __wasm_simd128__
int8_t aux8[QK_K];
float sums[8] = {0};
uint32_t auxs[4];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
// Process blocks with SIMD
int8_t * a = aux8;
uint8_t m = 1;
for (int j = 0; j < QK_K; j += 128) {
for (int shift = 0; shift <= 6; shift += 2) {
v128_t v_m = wasm_i8x16_splat(m);
for (int l = 0; l < 32; l += 16) {
v128_t v_q3 = wasm_v128_load(q3 + l);
v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
v128_t v_hm = wasm_v128_load(hm + l);
v128_t v_mask = wasm_v128_and(v_hm, v_m);
v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
wasm_v128_store(a + l, v_low2);
}
a += 32;
m <<= 1;
}
q3 += 32;
}
// Extract scales
memcpy(auxs, x[i].scales, 12);
uint32_t tmp = auxs[2];
auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
const int8_t * scales = (const int8_t *)auxs;
// SIMD dot product with register accumulators
v128_t v_acc0 = wasm_i32x4_splat(0);
v128_t v_acc1 = wasm_i32x4_splat(0);
a = aux8;
for (int j = 0; j < QK_K/16; ++j) {
const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
// Process 16 elements per iteration
for (int k = 0; k < 2; ++k) {
const v128_t v_q8 = wasm_i16x8_load8x8(q8);
const v128_t v_a = wasm_i16x8_load8x8(a);
v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
v_prod = wasm_i16x8_mul(v_prod, v_scale);
v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
q8 += 8;
a += 8;
}
}
// Accumulate results
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const v128_t v_d = wasm_f32x4_splat(d);
v128_t v_sum = wasm_f32x4_add(
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
);
// Accumulate into sums vector
wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
}
// Horizontal sum
v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
sumf = wasm_f32x4_extract_lane(v_sum, 0) +
wasm_f32x4_extract_lane(v_sum, 1) +
wasm_f32x4_extract_lane(v_sum, 2) +
wasm_f32x4_extract_lane(v_sum, 3);
*s = sumf;
#elif defined __riscv_v_intrinsic
uint32_t aux[3];
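The unpack stage of the q3_K path above reproduces, with v128 ops, this scalar rule: take two bits from qs and subtract 4 whenever the matching hmask bit is clear. A hedged scalar sketch (illustrative helper, QK_K assumed to be 256; the 32-byte hmask covers the whole block through the moving bit m):

#include <stdint.h>

// Scalar sketch of the q3_K unpack into signed values in [-4, 3].
static void q3_K_unpack_sketch(const uint8_t * q3, const uint8_t * hm, int8_t * a) {
    uint8_t m = 1;
    for (int j = 0; j < 256; j += 128) {
        for (int shift = 0; shift <= 6; shift += 2) {
            for (int l = 0; l < 32; ++l) {
                const int low2 = (q3[l] >> shift) & 3;
                *a++ = (int8_t)(low2 - ((hm[l] & m) ? 0 : 4));
            }
            m <<= 1;
        }
        q3 += 32;
    }
}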
@@ -5573,88 +5986,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
uint32_t utmp[4];
-#ifdef __ARM_FEATURE_SVE
+#ifdef __ARM_NEON
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
uint32x2_t mins8 = { 0 };
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1;
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
sumf -= dmin * vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
svint32_t sumi1_1 = svdup_n_s32(0);
svint32_t sumi1_2 = svdup_n_s32(0);
svint32_t sumi2 = svdup_n_s32(0);
svint32_t sumi2_1 = svdup_n_s32(0);
svint32_t sumi2_2 = svdup_n_s32(0);
switch (vector_length) {
case 128:
{
for (int j = 0; j < QK_K/64; ++j) {
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
q4 += 32;
}
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
} break;
case 256:
case 512:
{
for (int j = 0; j < QK_K/64; ++j) {
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
}
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
} break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sumf;
#elif __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);
@@ -5717,6 +6049,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = sumf;
#elif defined __wasm_simd128__
const uint8_t * scales = (const uint8_t*)&utmp[0];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
// Process scales and mins
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
// Sum mins * q8sums
int32_t sumi = 0;
const int16_t * restrict q8sums = y[i].bsums;
const uint8_t * m = (const uint8_t *)&utmp[2];
for (int j = 0; j < 16; j += 2) {
sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
}
sumf -= dmin * sumi;
int32_t sumi1 = 0;
int32_t sumi2 = 0;
for (int j = 0; j < QK_K/64; ++j) {
// Load 64 4-bit weights (32 bytes)
const v128_t q4x0 = wasm_v128_load(q4);
const v128_t q4x1 = wasm_v128_load(q4 + 16);
q4 += 32;
// Split into low/high nibbles
const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
// Load 64 8-bit values (64 bytes)
const v128_t q8x0 = wasm_v128_load(q8);
const v128_t q8x1 = wasm_v128_load(q8 + 16);
const v128_t q8x2 = wasm_v128_load(q8 + 32);
const v128_t q8x3 = wasm_v128_load(q8 + 48);
q8 += 64;
// Low nibble products
v128_t vacc1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4l0),
wasm_i16x8_extend_low_i8x16(q8x0)
);
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4l0),
wasm_i16x8_extend_high_i8x16(q8x0)
));
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4l1),
wasm_i16x8_extend_low_i8x16(q8x1)
));
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4l1),
wasm_i16x8_extend_high_i8x16(q8x1)
));
// High nibble products
v128_t vacc2 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4h0),
wasm_i16x8_extend_low_i8x16(q8x2)
);
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4h0),
wasm_i16x8_extend_high_i8x16(q8x2)
));
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4h1),
wasm_i16x8_extend_low_i8x16(q8x3)
));
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4h1),
wasm_i16x8_extend_high_i8x16(q8x3)
));
// Accumulate scaled results
int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
sumi1 += vacc1_sum * scales[2*j];
int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
sumi2 += vacc2_sum * scales[2*j+1];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
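After the kmask shuffling in the q4_K path above has produced 8 scales and 8 mins, the per-block accumulation the SIMD loop performs boils down to this sketch (subdot[g] stands for the 32-wide integer dot product of the g-th nibble group against q8; the name is introduced here for illustration only):

#include <stdint.h>

// q4_K x q8_K per-block combine: scale-weighted sub-dots minus the mins term.
static float q4_K_block_combine_sketch(const int32_t * subdot, const uint8_t * scales,
                                       const uint8_t * mins, const int16_t * bsums,
                                       float d, float dmin) {
    int32_t sumi = 0;
    for (int g = 0; g < 8; ++g) {
        sumi += (int32_t)scales[g] * subdot[g];
    }
    int32_t sumi_mins = 0;
    for (int g = 0; g < 8; ++g) {
        sumi_mins += (int32_t)mins[g] * (bsums[2*g] + bsums[2*g + 1]);
    }
    return d * (float)sumi - dmin * (float)sumi_mins;
}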
@@ -6469,6 +6902,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc) + summs;
#elif defined __wasm_simd128__
//const uint8_t * scales = (const uint8_t*)&utmp[0];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
// Process scales and mins
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
// Sum mins * q8sums
int32_t sumi_mins = 0;
const int16_t * restrict q8sums = y[i].bsums;
const uint8_t * m = (const uint8_t *)&utmp[2];
for (int j = 0; j < 16; j += 2) {
sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
}
sumf -= dmin * sumi_mins; // Correct subtraction
v128_t qh0 = wasm_v128_load(qh);
v128_t qh1 = wasm_v128_load(qh + 16);
const uint8_t * sc = (const uint8_t *)utmp;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
const int shift = j * 2;
v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
v128_t q5_0 = wasm_v128_load(q5);
v128_t q5_1 = wasm_v128_load(q5 + 16);
q5 += 32;
v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
v128_t q8_0 = wasm_v128_load(q8);
v128_t q8_1 = wasm_v128_load(q8 + 16);
v128_t q8_2 = wasm_v128_load(q8 + 32);
v128_t q8_3 = wasm_v128_load(q8 + 48);
q8 += 64;
// Process low quants
v128_t pl0 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5l_0),
wasm_i16x8_extend_low_i8x16(q8_0)
);
pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5l_0),
wasm_i16x8_extend_high_i8x16(q8_0)
));
v128_t pl1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5l_1),
wasm_i16x8_extend_low_i8x16(q8_1)
);
pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5l_1),
wasm_i16x8_extend_high_i8x16(q8_1)
));
v128_t sum_low = wasm_i32x4_add(pl0, pl1);
// Process high quants
v128_t ph0 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5h_0),
wasm_i16x8_extend_low_i8x16(q8_2)
);
ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5h_0),
wasm_i16x8_extend_high_i8x16(q8_2)
));
v128_t ph1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5h_1),
wasm_i16x8_extend_low_i8x16(q8_3)
);
ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5h_1),
wasm_i16x8_extend_high_i8x16(q8_3)
));
v128_t sum_high = wasm_i32x4_add(ph0, ph1);
// Accumulate with scale factors
int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
sumi += sl * sc[2*j] + sh * sc[2*j+1];
}
sumf += d * sumi;
}
*s = sumf;
#elif defined __riscv_v_intrinsic
const uint8_t * scales = (const uint8_t*)&utmp[0];
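The shift/AND/OR sequence in the q5_K path above reconstructs each 5-bit value from a nibble in qs plus one bit from qh; in scalar terms, for group j (0..3) and byte b (0..31), it is roughly the following sketch (illustrative helpers, not code from this commit). The low values pair with q8[64*j .. 64*j+31] and the high values with the next 32, each group weighted by sc[2*j] and sc[2*j+1] before the final d/dmin combine.

#include <stdint.h>

// Scalar sketch of the q5_K value reconstruction mirrored by the SIMD code.
static inline int q5_K_low_sketch (const uint8_t * q5, const uint8_t * qh, int j, int b) {
    return (q5[32*j + b] & 0x0F) | (((qh[b] >> (2*j))     & 1) << 4);
}
static inline int q5_K_high_sketch(const uint8_t * q5, const uint8_t * qh, int j, int b) {
    return (q5[32*j + b] >> 4)   | (((qh[b] >> (2*j + 1)) & 1) << 4);
}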
@@ -7132,89 +7677,83 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
-#elif defined __riscv_v_intrinsic
+#elif defined __wasm_simd128__
int8_t aux8[QK_K] __attribute__((aligned(16)));
int32_t aux32[8] __attribute__((aligned(16))) = {0};
float sums[8] __attribute__((aligned(16))) = {0};
float sumf = 0;
for (int i = 0; i < nb; ++i) {
// Unpack 6-bit quantized data into aux8 (unchanged)
-const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
-const int8_t * restrict q8 = y[i].qs;
+int8_t * a = aux8;
for (int j = 0; j < QK_K; j += 128) {
-const int8_t * restrict scale = x[i].scales;
+for (int l = 0; l < 32; ++l) {
a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-size_t vl;
+a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
}
-int sum_t = 0;
+a += 128;
-int is = 0;
+q4 += 64;
qh += 32;
for (int j = 0; j < QK_K/128; ++j) {
vl = 32;
// load qh
vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
// load Q6
vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
// load Q8 and take product
vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
vl = 16;
vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
q6 += 64; qh += 32; q8 += 128; is=8;
}
-sumf += d * sum_t;
+const int8_t * restrict a_ptr = aux8;
const int8_t * restrict q8 = y[i].qs;
v128_t acc0 = wasm_i32x4_splat(0);
v128_t acc1 = wasm_i32x4_splat(0);
for (int j = 0; j < QK_K/16; ++j) {
const int scale = x[i].scales[j];
const v128_t vscale = wasm_i32x4_splat(scale);
// Load 16 elements from a and q8
const v128_t a_vec = wasm_v128_load(a_ptr);
const v128_t q8_vec = wasm_v128_load(q8);
// Process low 8 elements
v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
// Process high 8 elements
v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
// Scale and accumulate
prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
a_ptr += 16;
q8 += 16;
}
// Store accumulated results
wasm_v128_store(&aux32[0], acc0);
wasm_v128_store(&aux32[4], acc1);
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) {
sums[l] += d * aux32[l];
}
}
// Sum final results
float sumf = 0;
for (int l = 0; l < 8; ++l) {
sumf += sums[l];
}
*s = sumf;
#elif defined(__POWER9_VECTOR__)
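Finally, once aux8 holds the unpacked signed 6-bit values, the scaled dot the q6_K WASM loop above accumulates is, per block, roughly the following sketch (illustrative name, QK_K assumed to be 256; not code from this commit):

#include <stdint.h>

// q6_K x q8_K per-block combine: 16 groups of 16 values, each group weighted by
// its int8 scale, the whole sum scaled by d = fp16(x.d) * y.d.
static float q6_K_block_dot_sketch(const int8_t * aux8, const int8_t * q8,
                                   const int8_t * scales, float d) {
    int32_t sum = 0;
    for (int g = 0; g < 16; ++g) {
        int32_t gsum = 0;
        for (int l = 0; l < 16; ++l) {
            gsum += (int32_t)aux8[16*g + l] * (int32_t)q8[16*g + l];
        }
        sum += (int32_t)scales[g] * gsum;
    }
    return d * (float)sum;
}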