ggml : initial ARM_NEON 2x F16 Q4_0 implementation
commit 7da998da01
parent efd05648c8

2 changed files with 234 additions and 102 deletions
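Summary: each 32-weight q4_0 block is split into two halves of 16 weights, and the single F32 delta is replaced by two F16 deltas, one per half. The block stays 20 bytes, while quantization error can drop wherever the two halves differ in magnitude. The quantize, dequantize, and dot-product kernels are reworked below to apply the two scales separately.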
Makefile (9 changed lines)
@@ -31,10 +31,15 @@ endif
 #
 # keep standard at C11 and C++11
-CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+CFLAGS   = -I.              -O3 -std=c11   -fPIC
+CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
 LDFLAGS  =
 
+ifndef LLAMA_NO_NDEBUG
+CFLAGS   += -DNDEBUG
+CXXFLAGS += -DNDEBUG
+endif
+
 # warnings
 CFLAGS   += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
 CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar
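Note on the Makefile change: -DNDEBUG is still defined by default, so ordinary builds keep asserts compiled out, but building with make LLAMA_NO_NDEBUG=1 now skips it, presumably so the assert-heavy paths touched below can be exercised with assert() active.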
ggml.c (299 changed lines)
@@ -572,15 +572,18 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
 
 #define QK4_0 32
 typedef struct {
-    float   d;          // delta
+    ggml_fp16_t d0;     // delta 0
+    ggml_fp16_t d1;     // delta 1
+
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == 2*sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 typedef struct {
     float   d;          // delta
     float   m;          // min
 
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
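Why this layout change is size-neutral: two 2-byte F16 deltas occupy exactly the 4 bytes the single F32 delta used, so a block still encodes 32 weights in 20 bytes, i.e. 5.0 bits per weight either way. A minimal standalone sketch of the new layout, assuming ggml_fp16_t is a 2-byte storage type as in ggml.h:

#include <stdint.h>

typedef uint16_t ggml_fp16_t;       // assumption: 2-byte half-precision storage

#define QK4_0 32

typedef struct {
    ggml_fp16_t d0;                 // delta for weights  0..15
    ggml_fp16_t d1;                 // delta for weights 16..31
    uint8_t     qs[QK4_0 / 2];      // 32 x 4-bit quants, two per byte
} block_q4_0;

// same 20 bytes as the old { float d; uint8_t qs[16]; } layout
_Static_assert(sizeof(block_q4_0) == 2*sizeof(ggml_fp16_t) + QK4_0/2,
               "wrong q4_0 block size/padding");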
@@ -588,6 +591,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
+
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
@@ -596,26 +600,36 @@ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
     assert(k % QK4_0 == 0);
 
     const int nb = k / QK4_0;
 
-    uint8_t pp[QK4_0/2];
-
     for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
+        float amax0 = 0.0f; // absolute max
+        float amax1 = 0.0f; // absolute max
 
-        for (int l = 0; l < QK4_0; l++) {
+        for (int l = 0; l < QK4_0/2; l++) {
             const float v = x[i*QK4_0 + l];
-            amax = MAX(amax, fabsf(v));
+            amax0 = MAX(amax0, fabsf(v));
         }
 
-        const float d  = amax / ((1 << 3) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
+        for (int l = QK4_0/2; l < QK4_0; l++) {
+            const float v = x[i*QK4_0 + l];
+            amax1 = MAX(amax1, fabsf(v));
+        }
 
-        y[i].d = d;
+        const float d0 = amax0 / ((1 << 3) - 1);
+        const float d1 = amax1 / ((1 << 3) - 1);
 
-        for (int l = 0; l < QK4_0; l += 2) {
-            const float v0 = x[i*QK4_0 + l + 0]*id;
-            const float v1 = x[i*QK4_0 + l + 1]*id;
+        const float id0 = d0 ? 1.0f/d0 : 0.0f;
+        const float id1 = d1 ? 1.0f/d1 : 0.0f;
+
+        y[i].d0 = GGML_FP32_TO_FP16(d0);
+        y[i].d1 = GGML_FP32_TO_FP16(d1);
+
+        for (int l = 0; l < QK4_0/2; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id0;
+            const float v1 = x[i*QK4_0 + l + 1]*id0;
 
             const uint8_t vi0 = (int8_t)roundf(v0) + 8;
             const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -623,15 +637,27 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
             assert(vi0 < 16);
             assert(vi1 < 16);
 
-            pp[l/2] = vi0 | (vi1 << 4);
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
         }
 
-        memcpy(y[i].qs, pp, sizeof(pp));
+        for (int l = QK4_0/2; l < QK4_0; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id1;
+            const float v1 = x[i*QK4_0 + l + 1]*id1;
+
+            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
+            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
     }
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK4_0 == 0);
 
     const int nb = k / QK4_0;
 
     block_q4_0 * restrict y = vy;
 
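The reference path now just runs the old symmetric scheme twice per block: each 16-value half picks its own delta d = amax/7, codes each value as round(v/d) + 8, and stores two codes per byte. A self-contained model of one half's quantize/dequantize round trip (hypothetical helper, not part of ggml.c):

#include <math.h>
#include <stdio.h>

// quantize 16 floats with their own delta, then reconstruct (a sketch)
static void q4_half_roundtrip(const float *v, float *out) {
    float amax = 0.0f;                        // absolute max of this half
    for (int l = 0; l < 16; l++) amax = fmaxf(amax, fabsf(v[l]));

    const float d  = amax / ((1 << 3) - 1);   // 7 = largest 4-bit magnitude
    const float id = d ? 1.0f/d : 0.0f;

    for (int l = 0; l < 16; l++) {
        const int q = (int)roundf(v[l]*id) + 8;   // 4-bit code, here in [1,15]
        out[l] = (q - 8)*d;                       // dequantized value
    }
}

int main(void) {
    float v[16], out[16];
    for (int l = 0; l < 16; l++) v[l] = 0.1f*(l - 8);
    q4_half_roundtrip(v, out);
    for (int l = 0; l < 16; l++) printf("% .3f -> % .3f\n", (double)v[l], (double)out[l]);
    return 0;
}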
@@ -678,24 +704,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
     for (int i = 0; i < nb; i++) {
         float32x4_t srcv [8];
         float32x4_t asrcv[8];
-        float32x4_t amaxv[8];
+
+        float32x4_t amaxv0[4];
+        float32x4_t amaxv1[4];
 
         for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
         for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
 
-        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
-        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
-        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+        for (int l = 0; l < 2; l++) amaxv0[2*l] = vmaxq_f32(asrcv [2*l], asrcv [2*l+1]);
+        for (int l = 0; l < 1; l++) amaxv0[4*l] = vmaxq_f32(amaxv0[4*l], amaxv0[4*l+2]);
 
-        const float amax = vmaxvq_f32(amaxv[0]);
+        for (int l = 0; l < 2; l++) amaxv1[2*l] = vmaxq_f32(asrcv [4+2*l], asrcv [4+2*l+1]);
+        for (int l = 0; l < 1; l++) amaxv1[4*l] = vmaxq_f32(amaxv1[4*l], amaxv1[4*l+2]);
 
-        const float d = amax / ((1 << 3) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
+        const float amax0 = vmaxvq_f32(amaxv0[0]);
+        const float amax1 = vmaxvq_f32(amaxv1[0]);
 
-        y[i].d = d;
+        const float d0 = amax0 / ((1 << 3) - 1);
+        const float d1 = amax1 / ((1 << 3) - 1);
 
-        for (int l = 0; l < 8; l++) {
-            const float32x4_t v = vmulq_n_f32(srcv[l], id);
+        const float id0 = d0 ? 1.0f/d0 : 0.0f;
+        const float id1 = d1 ? 1.0f/d1 : 0.0f;
+
+        y[i].d0 = d0;
+        y[i].d1 = d1;
+
+        for (int l = 0; l < 4; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id0);
             const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
             const int32x4_t   vi = vcvtq_s32_f32(vf);
 
             y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
             y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
         }
+
+        for (int l = 4; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id1);
+            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
+            const int32x4_t   vi = vcvtq_s32_f32(vf);
+
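The vdupq_n_f32(8.5f) addition does two things at once: it folds in the +8 bias, and, because vcvtq_s32_f32 truncates toward zero and the biased operand is non-negative, the extra 0.5 turns that truncation into round-half-up of v*id. Note the reference path's roundf rounds halves away from zero, so the two can disagree on exact .5 ties of negative scaled values; away from ties they match, as this scalar model checks (a sketch, with an arbitrary example delta):

#include <assert.h>
#include <math.h>

static int quant_simd_model(float v, float id) {
    return (int)(v*id + 8.5f);        // per lane: vcvtq_s32_f32(vaddq_f32(v, 8.5))
}

static int quant_ref_model(float v, float id) {
    return (int)roundf(v*id) + 8;     // the reference path
}

int main(void) {
    const float d  = 0.25f;           // example delta
    const float id = 1.0f/d;
    for (int k = -7; k <= 7; k++) {
        const float v = 0.13f*k;      // sample values that avoid exact .5 ties
        assert(quant_simd_model(v, id) == quant_ref_model(v, id));
    }
    return 0;
}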
@@ -1237,70 +1281,94 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
     }
 #elif defined(__ARM_NEON)
     for (int i = 0; i < nb; i++) {
-        const float32x4_t vd = vdupq_n_f32(x[i].d);
+        const float32x4_t vd_0 = vdupq_n_f32(GGML_FP16_TO_FP32(x[i].d0));
+        const float32x4_t vd_1 = vdupq_n_f32(GGML_FP16_TO_FP32(x[i].d1));
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK4_0; l += 16) {
         // Load 16x4-bit integers into 8x8-bit integers
-        const uint8x8_t v8 = vld1_u8(pp + l/2);
+        const uint8x8_t v8_0 = vld1_u8(pp + 0);
+        const uint8x8_t v8_1 = vld1_u8(pp + 8);
 
         // Expand 4-bit qs to 8-bit bytes
-        const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
-        const uint8x8_t v1 = vshr_n_u8(v8, 4);
+        const uint8x8_t v0_0 = vand_u8 (v8_0, vdup_n_u8(0x0f));
+        const uint8x8_t v1_0 = vshr_n_u8(v8_0, 4);
+        const uint8x8_t v0_1 = vand_u8 (v8_1, vdup_n_u8(0x0f));
+        const uint8x8_t v1_1 = vshr_n_u8(v8_1, 4);
 
         // Convert to signed 8-bit integers
-        const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
-        const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+        const int8x8_t vs_0_0 = vreinterpret_s8_u8(v0_0);
+        const int8x8_t vs_1_0 = vreinterpret_s8_u8(v1_0);
+        const int8x8_t vs_0_1 = vreinterpret_s8_u8(v0_1);
+        const int8x8_t vs_1_1 = vreinterpret_s8_u8(v1_1);
 
         // Subtract 8 from each byte
-        const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
-        const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+        const int8x8_t vb_0_0 = vsub_s8(vs_0_0, vdup_n_s8(8));
+        const int8x8_t vb_1_0 = vsub_s8(vs_1_0, vdup_n_s8(8));
+        const int8x8_t vb_0_1 = vsub_s8(vs_0_1, vdup_n_s8(8));
+        const int8x8_t vb_1_1 = vsub_s8(vs_1_1, vdup_n_s8(8));
 
         // Interleave and combine
-        const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
-        const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+        const int8x8_t vx_0_0 = vzip1_s8(vb_0_0, vb_1_0);
+        const int8x8_t vx_1_0 = vzip2_s8(vb_0_0, vb_1_0);
+        const int8x8_t vx_0_1 = vzip1_s8(vb_0_1, vb_1_1);
+        const int8x8_t vx_1_1 = vzip2_s8(vb_0_1, vb_1_1);
 
-        const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+        const int8x16_t vq_0 = vcombine_s8(vx_0_0, vx_1_0);
+        const int8x16_t vq_1 = vcombine_s8(vx_0_1, vx_1_1);
 
         // convert to 2x int16x8_t
-        const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
-        const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+        const int16x8_t vi_0_0 = vmovl_s8(vget_low_s8 (vq_0));
+        const int16x8_t vi_1_0 = vmovl_s8(vget_high_s8(vq_0));
+        const int16x8_t vi_0_1 = vmovl_s8(vget_low_s8 (vq_1));
+        const int16x8_t vi_1_1 = vmovl_s8(vget_high_s8(vq_1));
 
         // convert to 4x float32x4_t
-        const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
-        const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
-        const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
-        const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+        const float32x4_t vf_0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0_0)));
+        const float32x4_t vf_1_0 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0_0)));
+        const float32x4_t vf_2_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1_0)));
+        const float32x4_t vf_3_0 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1_0)));
+        const float32x4_t vf_0_1 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0_1)));
+        const float32x4_t vf_1_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0_1)));
+        const float32x4_t vf_2_1 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1_1)));
+        const float32x4_t vf_3_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1_1)));
 
         // Multiply by d
-        const float32x4_t r0 = vmulq_f32(vf_0, vd);
-        const float32x4_t r1 = vmulq_f32(vf_1, vd);
-        const float32x4_t r2 = vmulq_f32(vf_2, vd);
-        const float32x4_t r3 = vmulq_f32(vf_3, vd);
+        const float32x4_t r0_0 = vmulq_f32(vf_0_0, vd_0);
+        const float32x4_t r1_0 = vmulq_f32(vf_1_0, vd_0);
+        const float32x4_t r2_0 = vmulq_f32(vf_2_0, vd_0);
+        const float32x4_t r3_0 = vmulq_f32(vf_3_0, vd_0);
+        const float32x4_t r0_1 = vmulq_f32(vf_0_1, vd_1);
+        const float32x4_t r1_1 = vmulq_f32(vf_1_1, vd_1);
+        const float32x4_t r2_1 = vmulq_f32(vf_2_1, vd_1);
+        const float32x4_t r3_1 = vmulq_f32(vf_3_1, vd_1);
 
         // Store
-        vst1q_f32(y + i*QK4_0 + l +  0, r0);
-        vst1q_f32(y + i*QK4_0 + l +  4, r1);
-        vst1q_f32(y + i*QK4_0 + l +  8, r2);
-        vst1q_f32(y + i*QK4_0 + l + 12, r3);
-        }
+        vst1q_f32(y + i*QK4_0 +  0, r0_0);
+        vst1q_f32(y + i*QK4_0 +  4, r1_0);
+        vst1q_f32(y + i*QK4_0 +  8, r2_0);
+        vst1q_f32(y + i*QK4_0 + 12, r3_0);
+        vst1q_f32(y + i*QK4_0 + 16, r0_1);
+        vst1q_f32(y + i*QK4_0 + 20, r1_1);
+        vst1q_f32(y + i*QK4_0 + 24, r2_1);
+        vst1q_f32(y + i*QK4_0 + 28, r3_1);
     }
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d0 = GGML_FP16_TO_FP32(x[i].d0);
+        const float d1 = GGML_FP16_TO_FP32(x[i].d1);
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK4_0; l += 2) {
+        for (int l = 0; l < QK4_0/2; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
             const int8_t vi1 = vi >> 4;
 
-            const float v0 = (vi0 - 8)*d;
-            const float v1 = (vi1 - 8)*d;
+            const float v0 = (vi0 - 8)*d0;
+            const float v1 = (vi1 - 8)*d0;
 
             //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
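With QK4_0 fixed at 32, the old 16-at-a-time loop ran exactly twice, so it is now fully unrolled: the two 8-byte nibble loads (pp + 0, pp + 8) cover the whole block, and each unpacked 16-lane half gets its own scale (vd_0, vd_1). The lane order is worth spelling out: byte j of qs packs quant 2j in its low nibble and quant 2j+1 in its high nibble, so vand/vshr yield the even- and odd-indexed quants, and vzip1/vzip2 interleave them back into natural order. A scalar model of that reordering (a sketch):

#include <assert.h>
#include <stdint.h>

int main(void) {
    // pack quants q[0..31] two per byte, low nibble first, as quantize does
    uint8_t q[32], qs[16];
    for (int l = 0; l < 32; l++) q[l] = (uint8_t)(l & 0xf);   // any 4-bit data
    for (int j = 0; j < 16; j++) qs[j] = (uint8_t)(q[2*j] | (q[2*j + 1] << 4));

    uint8_t lo[16], hi[16], z[32];
    for (int j = 0; j < 16; j++) {
        lo[j] = qs[j] & 0x0f;          // "vand": quants 0,2,4,...,30
        hi[j] = qs[j] >> 4;            // "vshr": quants 1,3,5,...,31
    }
    for (int j = 0; j < 16; j++) {     // "vzip1/vzip2": restore natural order
        z[2*j + 0] = lo[j];
        z[2*j + 1] = hi[j];
    }
    for (int l = 0; l < 32; l++) assert(z[l] == q[l]);
    return 0;
}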
@@ -1310,6 +1378,22 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
             assert(!isnan(y[i*QK4_0 + l + 0]));
             assert(!isnan(y[i*QK4_0 + l + 1]));
         }
+
+        for (int l = QK4_0/2; l < QK4_0; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d1;
+            const float v1 = (vi1 - 8)*d1;
+
+            y[i*QK4_0 + l + 0] = v0;
+            y[i*QK4_0 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_0 + l + 0]));
+            assert(!isnan(y[i*QK4_0 + l + 1]));
+        }
     }
 #endif
 }
@@ -2250,14 +2334,19 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+        //int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+        //int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+        //p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+        //p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+        //sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        //sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+
+        sum0 += (GGML_FP16_TO_FP32(x0->d0)*GGML_FP16_TO_FP32(y0->d0))*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls));
+        sum0 += (GGML_FP16_TO_FP32(x0->d1)*GGML_FP16_TO_FP32(y0->d1))*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0hs, v1_0hs));
+        sum1 += (GGML_FP16_TO_FP32(x1->d0)*GGML_FP16_TO_FP32(y1->d0))*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls));
+        sum1 += (GGML_FP16_TO_FP32(x1->d1)*GGML_FP16_TO_FP32(y1->d1))*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1hs, v1_1hs));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2517,14 +2606,13 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
         const uint8_t * restrict p0 = x[i].qs;
         const uint8_t * restrict p1 = y[i].qs;
 
-        int sumi = 0;
-        for (int j = 0; j < QK4_0/2; j++) {
+        int sumi_0 = 0;
+        int sumi_1 = 0;
+
+        for (int j = 0; j < QK4_0/4; j++) {
             const uint8_t v0 = p0[j];
             const uint8_t v1 = p1[j];
 
@@ -2534,9 +2622,24 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
             const int i2 = (v1 & 0xf) - 8;
             const int i3 = (v1 >>  4) - 8;
 
-            sumi += i0*i2 + i1*i3;
+            sumi_0 += i0*i2 + i1*i3;
         }
-        sumf += d0 * d1 * sumi;
+
+        for (int j = QK4_0/4; j < QK4_0/2; j++) {
+            const uint8_t v0 = p0[j];
+            const uint8_t v1 = p1[j];
+
+            const int i0 = (v0 & 0xf) - 8;
+            const int i1 = (v0 >>  4) - 8;
+
+            const int i2 = (v1 & 0xf) - 8;
+            const int i3 = (v1 >>  4) - 8;
+
+            sumi_1 += i0*i2 + i1*i3;
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d0) * GGML_FP16_TO_FP32(y[i].d0)) * sumi_0;
+        sumf += (GGML_FP16_TO_FP32(x[i].d1) * GGML_FP16_TO_FP32(y[i].d1)) * sumi_1;
     }
 #endif
 
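Written out, the scalar accumulation per block pair is now

$$
\mathrm{sumf} \mathrel{+}= \mathrm{f32}(x_i.d_0)\,\mathrm{f32}(y_i.d_0)\sum_{l=0}^{15} q^{x}_{l}\,q^{y}_{l} \;+\; \mathrm{f32}(x_i.d_1)\,\mathrm{f32}(y_i.d_1)\sum_{l=16}^{31} q^{x}_{l}\,q^{y}_{l}
$$

where q = nibble - 8 lies in [-8, 7]. Each 16-term half sum is bounded in magnitude by 16*64 = 1024, so the plain int accumulators cannot overflow.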
@@ -2765,6 +2868,12 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2772,21 +2881,26 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+        //const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        //const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        //const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        //const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+        //int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+        //int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+        //p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+        //p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+        //sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        //sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+
+        sum0 += (GGML_FP16_TO_FP32(x0->d0)*y0->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l));
+        sum0 += (GGML_FP16_TO_FP32(x0->d1)*y0->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h));
+        sum1 += (GGML_FP16_TO_FP32(x1->d0)*y1->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l));
+        sum1 += (GGML_FP16_TO_FP32(x1->d1)*y1->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
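The q8_0 side stores its 32 int8 quants in natural order under a single F32 delta, so zipping the 4-bit side back to natural order (v0_0lz/v0_0hz) makes each 16-lane register line up with one half-block and with the contiguous halves v1_0l/v1_0h of y; previously the y halves were un-zipped (vuzp) to match the nibble order of x instead. A scalar model of the resulting per-half accumulation for one block pair (a sketch; the helper is hypothetical, d0/d1 are the x deltas already widened to F32, xq holds the x quants after the -8 shift and re-interleave):

#include <stdint.h>

static float block_dot_q4_q8_model(float d0, float d1, const int8_t xq[32],
                                   float yd, const int8_t yq[32]) {
    int sumi_0 = 0;
    int sumi_1 = 0;
    for (int l = 0;  l < 16; l++) sumi_0 += xq[l]*yq[l];  // half 0 -> scale d0
    for (int l = 16; l < 32; l++) sumi_1 += xq[l]*yq[l];  // half 1 -> scale d1
    return d0*yd*(float)sumi_0 + d1*yd*(float)sumi_1;
}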
@@ -2904,14 +3018,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
         const uint8_t * restrict p0 = x[i].qs;
         const  int8_t * restrict p1 = y[i].qs;
 
-        int sumi = 0;
-        for (int j = 0; j < QK8_0/2; j++) {
+        int sumi_0 = 0;
+        int sumi_1 = 0;
+
+        for (int j = 0; j < QK8_0/4; j++) {
             const uint8_t v0 = p0[j];
 
             const int i0 = (int8_t) (v0 & 0xf) - 8;
@@ -2920,9 +3033,23 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
             const int i2 = p1[2*j + 0];
             const int i3 = p1[2*j + 1];
 
-            sumi += i0*i2 + i1*i3;
+            sumi_0 += i0*i2 + i1*i3;
         }
-        sumf += d0*d1*sumi;
+
+        for (int j = QK8_0/4; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const int i0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1 = (int8_t) (v0 >>  4) - 8;
+
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
+
+            sumi_1 += i0*i2 + i1*i3;
+        }
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d0) * y[i].d) * sumi_0;
+        sumf += (GGML_FP16_TO_FP32(x[i].d1) * y[i].d) * sumi_1;
     }
 #endif
 
|
Loading…
Add table
Add a link
Reference in a new issue