q4_0c: Arm Neon acceleration

Mostly copied from the q4_0 implementation
This commit is contained in:
Håkon H. Hitland 2023-04-21 00:11:49 +02:00
parent ab543dc1a4
commit 1b49d26f8a

96
ggml.c
View file

@ -1758,7 +1758,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
int8_t * restrict qs = vy;
float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C);
#if __AVX512F__
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
ds[i] = d;
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
qs[i*QK8_0C + 4*l + 0] = vgetq_lane_s32(vi, 0);
qs[i*QK8_0C + 4*l + 1] = vgetq_lane_s32(vi, 1);
qs[i*QK8_0C + 4*l + 2] = vgetq_lane_s32(vi, 2);
qs[i*QK8_0C + 4*l + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__AVX512F__)
for (int i = 0; i < nb; i++) {
const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C );
const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2);
@ -3095,7 +3125,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
float sumf = 0.0;
#if __AVX512F__
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb/2; i++) {
const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11 ...
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_01l = vld1q_u8(&xqs[i*QK4_0]);
const uint8x16_t v0_01h = vld1q_u8(&xqs[i*QK4_0 + QK4_0/2]);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_01l, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vandq_u8 (v0_01h, m4b));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vshrq_n_u8(v0_01l, 4));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_01h, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
const int8x16_t v1_0l = vld1q_s8(&yqs[dst0*QK8_0C]);
const int8x16_t v1_0h = vld1q_s8(&yqs[dst0*QK8_0C + 16]);
const int8x16_t v1_1l = vld1q_s8(&yqs[dst1*QK8_0C]);
const int8x16_t v1_1h = vld1q_s8(&yqs[dst1*QK8_0C + 16]);
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), xds[dst0]*yds[dst0]);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), xds[dst1]*yds[dst1]);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), xds[dst0]*yds[dst0]);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), xds[dst1]*yds[dst1]);
#endif
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX512F__)
// Initialize accumulator with zeros
__m512 acc = _mm512_setzero_ps();
for (int i = 0; i < nb; i += 4) {