Q4_K dot product for ARM_NEON
This commit is contained in:
parent
54f808db2b
commit
a2533a72a3
1 changed files with 102 additions and 13 deletions
115
k_quants.c
115
k_quants.c
|
@ -5,7 +5,34 @@
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
|
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
||||||
|
//
|
||||||
|
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
|
||||||
|
//
|
||||||
|
#include <arm_neon.h>
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#ifdef __wasm_simd128__
|
||||||
|
#include <wasm_simd128.h>
|
||||||
|
#else
|
||||||
|
#ifdef __POWER9_VECTOR__
|
||||||
|
#include <altivec.h>
|
||||||
|
#undef bool
|
||||||
|
#define bool _Bool
|
||||||
|
#else
|
||||||
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
|
#include <intrin.h>
|
||||||
|
#else
|
||||||
|
#if !defined(__riscv)
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#undef MIN
|
#undef MIN
|
||||||
#undef MAX
|
#undef MAX
|
||||||
|
@ -1047,23 +1074,90 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef z__ARM_NEON
|
|
||||||
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
|
|
||||||
#elif defined __AVX2__
|
|
||||||
|
|
||||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||||
static const uint32_t kmask3 = 0x03030303;
|
static const uint32_t kmask3 = 0x03030303;
|
||||||
|
|
||||||
|
uint32_t utmp[4];
|
||||||
|
|
||||||
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
|
const uint32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
|
||||||
|
int8x16x4_t q4bytes;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
||||||
|
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
||||||
|
|
||||||
|
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
||||||
|
|
||||||
|
memcpy(utmp, x[i].scales, 12);
|
||||||
|
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
||||||
|
const uint32_t uaux = utmp[1] & kmask1;
|
||||||
|
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||||
|
utmp[2] = uaux;
|
||||||
|
utmp[0] &= kmask1;
|
||||||
|
|
||||||
|
const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
|
||||||
|
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
|
||||||
|
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
||||||
|
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
||||||
|
int32_t sumi_mins = vaddvq_s32(prod);
|
||||||
|
|
||||||
|
const uint8_t * scales = (const uint8_t *)utmp;
|
||||||
|
|
||||||
|
const uint8_t * restrict q4 = x[i].qs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
int32_t sumi = 0;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
|
||||||
|
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
|
||||||
|
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
q4bytes.val[2] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
|
q4bytes.val[3] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
|
|
||||||
|
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1])) * *scales++;
|
||||||
|
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q4bytes.val[2], q8bytes.val[2]), q4bytes.val[3], q8bytes.val[3])) * *scales++;
|
||||||
|
#else
|
||||||
|
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
|
||||||
|
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
sumf += d * sumi - dmin * sumi_mins;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = sumf;
|
||||||
|
|
||||||
|
#elif defined __AVX2__
|
||||||
|
|
||||||
const __m256i m4 = _mm256_set1_epi8(0xF);
|
const __m256i m4 = _mm256_set1_epi8(0xF);
|
||||||
const __m128i mzero = _mm_setzero_si128();
|
const __m128i mzero = _mm_setzero_si128();
|
||||||
|
|
||||||
__m256 acc = _mm256_setzero_ps();
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
uint32_t utmp[4];
|
|
||||||
|
|
||||||
float summs = 0.f;
|
float summs = 0.f;
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
@ -1124,11 +1218,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
|
||||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
|
||||||
static const uint32_t kmask3 = 0x03030303;
|
|
||||||
|
|
||||||
uint32_t utmp[4];
|
|
||||||
|
|
||||||
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
||||||
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue