q4_0c: Arm Neon acceleration
Mostly copied from the q4_0 implementation
This commit is contained in:
parent
ab543dc1a4
commit
1b49d26f8a
1 changed files with 94 additions and 2 deletions
96
ggml.c
96
ggml.c
|
@ -1758,7 +1758,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
|
|||
int8_t * restrict qs = vy;
|
||||
float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C);
|
||||
|
||||
#if __AVX512F__
|
||||
#if defined(__ARM_NEON)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t srcv [8];
|
||||
float32x4_t asrcv[8];
|
||||
float32x4_t amaxv[8];
|
||||
|
||||
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
||||
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
|
||||
|
||||
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
|
||||
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
|
||||
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
ds[i] = d;
|
||||
|
||||
for (int l = 0; l < 8; l++) {
|
||||
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
||||
const int32x4_t vi = vcvtnq_s32_f32(v);
|
||||
|
||||
qs[i*QK8_0C + 4*l + 0] = vgetq_lane_s32(vi, 0);
|
||||
qs[i*QK8_0C + 4*l + 1] = vgetq_lane_s32(vi, 1);
|
||||
qs[i*QK8_0C + 4*l + 2] = vgetq_lane_s32(vi, 2);
|
||||
qs[i*QK8_0C + 4*l + 3] = vgetq_lane_s32(vi, 3);
|
||||
}
|
||||
}
|
||||
#elif defined(__AVX512F__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C );
|
||||
const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2);
|
||||
|
@ -3095,7 +3125,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
|
|||
|
||||
float sumf = 0.0;
|
||||
|
||||
#if __AVX512F__
|
||||
#if defined(__ARM_NEON)
|
||||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < nb/2; i++) {
|
||||
const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
|
||||
const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11 ...
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
const int8x16_t s8b = vdupq_n_s8(0x8);
|
||||
|
||||
const uint8x16_t v0_01l = vld1q_u8(&xqs[i*QK4_0]);
|
||||
const uint8x16_t v0_01h = vld1q_u8(&xqs[i*QK4_0 + QK4_0/2]);
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_01l, m4b));
|
||||
const int8x16_t v0_0h = vreinterpretq_s8_u8(vandq_u8 (v0_01h, m4b));
|
||||
const int8x16_t v0_1l = vreinterpretq_s8_u8(vshrq_n_u8(v0_01l, 4));
|
||||
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_01h, 4));
|
||||
|
||||
// sub 8
|
||||
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
||||
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
||||
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
|
||||
// load y
|
||||
const int8x16_t v1_0l = vld1q_s8(&yqs[dst0*QK8_0C]);
|
||||
const int8x16_t v1_0h = vld1q_s8(&yqs[dst0*QK8_0C + 16]);
|
||||
const int8x16_t v1_1l = vld1q_s8(&yqs[dst1*QK8_0C]);
|
||||
const int8x16_t v1_1h = vld1q_s8(&yqs[dst1*QK8_0C + 16]);
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int32x4_t
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), xds[dst0]*yds[dst0]);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), xds[dst1]*yds[dst1]);
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
|
||||
|
||||
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), xds[dst0]*yds[dst0]);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), xds[dst1]*yds[dst1]);
|
||||
#endif
|
||||
}
|
||||
|
||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
|
||||
#elif defined(__AVX512F__)
|
||||
// Initialize accumulator with zeros
|
||||
__m512 acc = _mm512_setzero_ps();
|
||||
for (int i = 0; i < nb; i += 4) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue