ggml : fix q4_0

parent 3f68842e1c
commit 79b95e3420
1 changed file with 81 additions and 94 deletions
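
The change below reworks ggml_vec_dot_q4_0_q8_0 (and touches ggml_vec_dot_q4_1_q8_1) so that every SIMD branch accumulates into a shared sumf and advances a shared block index ib, while the former scalar fallback becomes a tail loop that finishes whatever blocks the SIMD path left over; the assert(nb % 2 == 0) guards are dropped. Below is a minimal sketch of the loop structure the diff converges on, with the per-ISA branches collapsed into hypothetical stand-ins (SIMD_PATH and simd_dot_pair are illustrative only, not ggml APIs):

/* Sketch only: shows the control flow this commit converges on.
 * block_q4_0, block_q8_0, QK8_0 and GGML_FP16_TO_FP32 are the usual ggml
 * definitions; SIMD_PATH and simd_dot_pair() are hypothetical stand-ins
 * for the per-ISA branches and kernels in the real function. */
static void vec_dot_q4_0_q8_0_sketch(int n, float * s, const block_q4_0 * x, const block_q8_0 * y) {
    const int qk = QK8_0;
    const int nb = n / qk;          // number of quantized blocks

    int   ib   = 0;                 // shared block index, advanced by whichever path runs
    float sumf = 0;                 // shared accumulator, written to *s exactly once

#if defined(SIMD_PATH)
    // unrolled SIMD main loop: consumes blocks two at a time, may stop early
    for (; ib + 1 < nb; ib += 2) {
        sumf += simd_dot_pair(&x[ib], &y[ib]);
    }
#endif

    // scalar tail: finishes the remaining blocks, so an odd nb no longer needs an assert
    for (; ib < nb; ++ib) {
        int sumi = 0;
        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >> 4) - 8;
            sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
        }
        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}
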
@@ -3808,11 +3808,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

         vst1_f32(s, vget_low_f32(sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
         return;
     }
 #endif
+
+    int ib = 0;
+    float sumf = 0;
+
 #if defined(__ARM_FEATURE_SVE)
     if (svcntb() == QK8_0) {
         const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
@@ -3821,13 +3825,11 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         svfloat32_t sumv0 = svdup_n_f32(0.0f);
         svfloat32_t sumv1 = svdup_n_f32(0.0f);

-        assert(nb % 2 == 0); // TODO: handle odd nb
-
-        for (int i = 0; i < nb; i += 2) {
-            const block_q4_0 * restrict x0 = &x[i + 0];
-            const block_q4_0 * restrict x1 = &x[i + 1];
-            const block_q8_0 * restrict y0 = &y[i + 0];
-            const block_q8_0 * restrict y1 = &y[i + 1];
+        for (; ib + 1 < nb; ib += 2) {
+            const block_q4_0 * restrict x0 = &x[ib + 0];
+            const block_q4_0 * restrict x1 = &x[ib + 1];
+            const block_q8_0 * restrict y0 = &y[ib + 0];
+            const block_q8_0 * restrict y1 = &y[ib + 1];

             // load x
             const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
@@ -3850,21 +3852,17 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
         }

-        *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
-        return;
+        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
     }
-#endif
-#if defined(__ARM_NEON)
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

-    assert(nb % 2 == 0); // TODO: handle odd nb
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib + 0];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib + 0];
+        const block_q8_0 * restrict y1 = &y[ib + 1];

         const uint8x16_t m4b = vdupq_n_u8(0x0F);
         const int8x16_t s8b = vdupq_n_s8(0x8);
@@ -3898,23 +3896,23 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }

-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

     // Main loop
-    for (int i = 0; i < nb; ++i) {
+    for (; ib < nb; ++ib) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

-        __m256i qx = bytes_from_nibbles_32(x[i].qs);
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);

         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m256i off = _mm256_set1_epi8( 8 );
         qx = _mm256_sub_epi8( qx, off );

-        __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);

         const __m256 q = mul_sum_i8_pairs_float(qx, qy);

@@ -3922,28 +3920,28 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc = _mm256_fmadd_ps( d, q, acc );
     }

-    *s = hsum_float_8(acc);
+    sumf = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

     // Main loop
-    for (int i = 0; i < nb; ++i) {
+    for (; ib < nb; ++ib) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

         const __m128i lowMask = _mm_set1_epi8(0xF);
         const __m128i off = _mm_set1_epi8(8);

-        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
+        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);

         __m128i bx_0 = _mm_and_si128(lowMask, tmp);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
         bx_0 = _mm_sub_epi8(bx_0, off);
         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

         bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-        by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
         bx_0 = _mm_sub_epi8(bx_0, off);
         const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);

@@ -3954,7 +3952,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }

-    *s = hsum_float_8(acc);
+    sumf = hsum_float_8(acc);
 #elif defined(__SSSE3__)
     // set constants
     const __m128i lowMask = _mm_set1_epi8(0xF);
@@ -4017,43 +4015,41 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc_3 = _mm_mul_ps( d_2_3, p3 );
     }

-    assert(nb % 2 == 0); // TODO: handle odd nb
-
     // Main loop
-    for (int i = 2; i < nb; i+=2) {
-        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+    for (; ib + 1 < nb; ib += 2) {
+        _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);

         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

-        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);

         __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
         bx_0 = _mm_sub_epi8(bx_0, off);
         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

         __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
-        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
         bx_1 = _mm_sub_epi8(bx_1, off);
         const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);

-        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+        _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );

-        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);

         __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
-        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
         bx_2 = _mm_sub_epi8(bx_2, off);
         const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);

         __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
-        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
         bx_3 = _mm_sub_epi8(bx_3, off);
         const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);

@@ -4076,18 +4072,18 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc_3 = _mm_add_ps(p3_d, acc_3);
     }

-    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #elif defined(__riscv_v_intrinsic)
     float sumf = 0.0;

     size_t vl = __riscv_vsetvl_e8m1(qk/2);

-    for (int i = 0; i < nb; i++) {
+    for (; ib < nb; ++ib) {
         // load elements
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);

-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);

         // mask and store lower part of x, and then upper part
         vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
@@ -4110,11 +4106,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r

         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

-        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
     }

-    *s = sumf;
-
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
     const vector signed int v0 = vec_splats((int32_t)0);
@@ -4124,17 +4118,17 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     vector float vsumf0 = vec_splats(0.0f);

 #pragma GCC unroll 8
-    for (int i = 0; i < nb; i++) {
-        __builtin_prefetch(x[i].qs, 0, 1);
-        __builtin_prefetch(y[i].qs, 0, 1);
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);

-        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
-        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
         vector float vd = vec_mul(vxd, vyd);

-        vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs);
-        vector signed char q8y0 = vec_xl( 0, y[i].qs);
-        vector signed char q8y1 = vec_xl(16, y[i].qs);
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);

         vector signed char q4x0 = vec_and(qxs, lowMask);
         vector signed char q4x1 = vec_sr(qxs, v4);
@@ -4156,24 +4150,24 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
     vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));

-    *s = vec_extract(vsumf0, 0);
+    sumf = vec_extract(vsumf0, 0);

 #elif defined(__loongarch_asx)
     // Initialize accumulator with zeros
     __m256 acc = (__m256)__lasx_xvldi(0);

     // Main loop
-    for (int i = 0; i < nb; ++i) {
+    for (; ib < nb; ++ib) {
         /* Compute combined scale for the block */
-        const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

-        __m256i qx = bytes_from_nibbles_32(x[i].qs);
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);

         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m256i off = __lasx_xvreplgr2vr_b( 8 );
         qx = __lasx_xvsub_b( qx, off );

-        __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0);
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);

         const __m256 q = mul_sum_i8_pairs_float(qx, qy);

@@ -4181,7 +4175,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc = __lasx_xvfmadd_s( d, q, acc );
     }

-    *s = hsum_float_8(acc);
+    sumf = hsum_float_8(acc);
 #elif defined(__loongarch_sx)
     // set constants
     const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
@@ -4241,41 +4235,39 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc_3 = __lsx_vfmul_s( d_2_3, p3 );
     }

-    assert(nb % 2 == 0); // TODO: handle odd nb
-
     // Main loop
-    for (int i = 2; i < nb; i+=2) {
+    for (; ib + 1 < nb; ib += 2) {

         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

-        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[i].qs, 0);
+        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);

         __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
-        __m128i by_0 = __lsx_vld((const __m128i *)y[i].qs, 0);
+        __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
         bx_0 = __lsx_vsub_b(bx_0, off);
         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

         __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
-        __m128i by_1 = __lsx_vld((const __m128i *)(y[i].qs + 16), 0);
+        __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
         bx_1 = __lsx_vsub_b(bx_1, off);
         const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);

-        //_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
-        //_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+        //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
+        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );

-        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[i + 1].qs, 0);
+        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);

         __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
-        __m128i by_2 = __lsx_vld((const __m128i *)y[i + 1].qs, 0);
+        __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
         bx_2 = __lsx_vsub_b(bx_2, off);
         const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);

         __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
-        __m128i by_3 = __lsx_vld((const __m128i *)(y[i + 1].qs + 16), 0);
+        __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
         bx_3 = __lsx_vsub_b(bx_3, off);
         const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);

@@ -4298,27 +4290,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc_3 = __lsx_vfadd_s(p3_d, acc_3);
     }

-    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
-
-#else
-    // scalar
-    float sumf = 0.0;
-
-    for (int i = 0; i < nb; i++) {
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#endif
+    for (; ib < nb; ++ib) {
         int sumi = 0;

         for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[i].qs[j] & 0x0F) - 8;
-            const int v1 = (x[i].qs[j] >> 4) - 8;
+            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
+            const int v1 = (x[ib].qs[j] >> 4) - 8;

-            sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+            sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
         }

-        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
     }

     *s = sumf;
-#endif
 }

 void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
@@ -4404,7 +4391,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
         sumv2 = vaddq_f32(sumv2, summs0);

-        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s, vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
         return;
     }