Arm AArch64: minor code refactoring to resolve a build issue with CMake

Dibakar Gope 2024-05-16 12:15:48 +00:00 committed by Dibakar Gope
parent 8ee6779147
commit a657246d62
3 changed files with 547 additions and 800 deletions
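
The first diff below collapses the per-feature GEMV/GEMM variants (_sve256, _neon, _neon_noi8mm) into single ggml_gemv_q4_0_q8_0_aarch64 / ggml_gemm_q4_0_q8_0_aarch64 entry points that pick a code path with compile-time feature guards plus a runtime SVE vector-length check. A minimal sketch of that selection pattern (hypothetical example_* name, kernel bodies elided) is:

// Sketch only: the consolidation pattern used in this commit, not the real implementation.
// Compile-time guards choose SVE, NEON+i8mm, or plain NEON; at run time the SVE branch
// additionally checks the vector length and falls through to NEON when it is not 256-bit.
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif
#if defined(__ARM_NEON)
#include <arm_neon.h>
#endif

void example_gemv_q4_0_q8_0_aarch64(int n, float * s, const void * vx, const void * vy,
                                    int nr, int nc, int ith, int nth) {
    // Mark every parameter used so the fallback build (no SVE, no NEON) stays warning-free,
    // mirroring the UNUSED(...) macros in the real function.
    (void) n; (void) s; (void) vx; (void) vy; (void) nr; (void) nc; (void) ith; (void) nth;
#if defined(__ARM_FEATURE_SVE)
    if (svcntw() == 8) {
        // 256-bit SVE vectors: run the SVE inline-assembly kernel here, then return.
        return;
    }
    // Other SVE vector widths fall through to the NEON paths below.
#endif
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    // NEON + int8 matmul (i8mm) kernel goes here.
#elif defined(__ARM_NEON)
    // Plain NEON (dot-product) kernel goes here.
#endif
}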


@@ -1,4 +1,8 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -315,12 +319,18 @@ inline int64_t roundup(const int64_t a, const int64_t b) {
}
}
void ggml_gemv_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
void ggml_gemv_q4_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(n);
UNUSED(s);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(ith);
UNUSED(nth);
#if defined(__ARM_FEATURE_SVE)
if (svcntw() != 8) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemv_q4_0_q8_0_aarch64_neon(n, s, vx, vy, nr, nc, ith, nth);
return;
}
if (svcntw() == 8) {
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
size_t width = xend - x0;
@@ -393,12 +403,10 @@ void ggml_gemv_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const vo
: [a_ptr] "r" (a_ptr), [num_blocks] "r" (num_blocks)
: "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
);
#endif
return;
}
void ggml_gemv_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(nr);
#if defined(__ARM_NEON)
#endif
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
size_t width = xend - x0;
@@ -470,12 +478,7 @@ void ggml_gemv_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void
: [a_ptr] "r" (a_ptr), [num_blocks] "r" (num_blocks)
: "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
);
#endif
}
void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(nr);
#if defined(__ARM_NEON)
#elif defined(__ARM_NEON)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
size_t width = xend - x0;
@@ -545,168 +548,18 @@ void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, con
#endif
}
void ggml_gemv_q8_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
void ggml_gemm_q4_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(n);
UNUSED(s);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(ith);
UNUSED(nth);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const svbool_t ptrue = svptrue_b8();
const block_q8_0x8 * b_ptr_start = (const block_q8_0x8 *) vx;
const block_q8_0 * a_ptr_start = (const block_q8_0 *) vy;
for (int64_t y = 0; y < nr; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
svfloat32_t acc_row = svdup_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const svint8_t rhs_vec_0_0_0 = svld1_s8(ptrue, b_ptr[b].qs);
const svint8_t rhs_vec_0_1_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 1);
const svint8_t rhs_vec_0_2_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 2);
const svint8_t rhs_vec_0_3_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 3);
const svint8_t rhs_vec_0_0_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 4);
const svint8_t rhs_vec_0_1_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 5);
const svint8_t rhs_vec_0_2_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 6);
const svint8_t rhs_vec_0_3_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 7);
// Scale values
const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d));
const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16);
const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d);
const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16);
const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs);
const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16);
svint32_t iacc = svdup_s32(0);
iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0);
iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1);
iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2);
iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3);
acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32));
}
svst1(ptrue, s + (y * nc + x * 8), acc_row);
}
}
#endif
}
void ggml_gemv_q8_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_NEON)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const block_q8_0x8 * b_ptr_start = (const block_q8_0x8 *) vx;
const block_q8_0 * a_ptr_start = (const block_q8_0 *) vy;
for (int64_t y = 0; y < nr; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
float32x4_t acc_row[2];
acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const int8x16_t rhs_vec_0_0_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_vec_1_0_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_vec_0_1_0 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_vec_1_1_0 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_vec_0_2_0 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_vec_1_2_0 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_vec_0_3_0 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_vec_1_3_0 = vld1q_s8(b_ptr[b].qs + 112);
const int8x16_t rhs_vec_0_0_1 = vld1q_s8(b_ptr[b].qs + 128);
const int8x16_t rhs_vec_1_0_1 = vld1q_s8(b_ptr[b].qs + 144);
const int8x16_t rhs_vec_0_1_1 = vld1q_s8(b_ptr[b].qs + 160);
const int8x16_t rhs_vec_1_1_1 = vld1q_s8(b_ptr[b].qs + 176);
const int8x16_t rhs_vec_0_2_1 = vld1q_s8(b_ptr[b].qs + 192);
const int8x16_t rhs_vec_1_2_1 = vld1q_s8(b_ptr[b].qs + 208);
const int8x16_t rhs_vec_0_3_1 = vld1q_s8(b_ptr[b].qs + 224);
const int8x16_t rhs_vec_1_3_1 = vld1q_s8(b_ptr[b].qs + 240);
// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x8_t col_scale_f16 = vld1q_f16((const ggml_fp16_internal_t *)(b_ptr[b].d));
const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16));
const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16));
const float16x4_t row_scale_f16 = vld1_dup_f16((const ggml_fp16_internal_t *)(&(a_ptr[b].d)));
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);
const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs);
const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16);
int32x4_t iacc0 = vdupq_n_s32(0);
int32x4_t iacc1 = vdupq_n_s32(0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3);
acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32));
acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32));
}
vst1q_f32(s + (y * nc + x * 8), acc_row[0]);
vst1q_f32(s + (y * nc + x * 8 + 4), acc_row[1]);
}
}
#endif
}
void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntw() != 8) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemm_q4_0_q8_0_aarch64_neon(n, s, vx, vy, nr, nc, ith, nth);
return;
}
if (svcntw() == 8) {
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
size_t width = xend - x0;
@@ -1124,10 +977,9 @@ void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const vo
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [num_blocks] "r" (num_blocks), [res_stride] "r" (res_stride), [width] "r" (width)
: "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
);
#endif
return;
}
void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#endif
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
@@ -1534,11 +1386,7 @@ void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [num_blocks] "r" (num_blocks), [res_stride] "r" (res_stride), [width] "r" (width)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
);
#endif
}
void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_NEON)
#elif defined(__ARM_NEON)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
size_t width = xend - x0;
@@ -2006,94 +1854,3 @@ void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, con
);
#endif
}
void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const block_q8_0x4 * b_ptr_start = (const block_q8_0x4 *) vx;
const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *) vy;
for (int64_t y = 0; y < nr / 4; y += nr / 4) {
for (int64_t x = x0 / 4; x < xend / 4; x++) {
const block_q8_0x4 ** a_ptrs = new const block_q8_0x4 * [nr / 4];
a_ptrs[0] = a_ptr_start + (y * a_nb);
for (int i = 0; i < (nr / 4) - 1; i++) {
a_ptrs[i + 1] = a_ptrs[i] + a_nb;
}
const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulators
float32x4_t * acc_rows = new float32x4_t[nr];
for (int i = 0; i < nr; i++) {
acc_rows[i] = vdupq_n_f32(0.0f);
}
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112);
// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x4_t col_scale_f16 = vld1_f16((const ggml_fp16_internal_t *)(b_ptr[b].d));
const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);
// Process LHS in pairs of rows
for (int rp = 0; rp < nr / 4; rp++) {
const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16);
const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32);
const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48);
const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64);
const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80);
const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96);
const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112);
// Do the MMLAs into 2x2 matrices
const int32x4_t iacc_mat_00 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
const int32x4_t iacc_mat_01 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
const int32x4_t iacc_mat_10 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3);
const int32x4_t iacc_mat_11 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3);
// Straighten out to make 4 row vectors
const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
const float16x4_t row_scale_f16 = vld1_f16((const ggml_fp16_internal_t *)(a_ptrs[rp][b].d));
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);
acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0));
acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1));
acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2));
acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3));
}
}
for (int i = 0; i < nr; i++) {
vst1q_f32(s + ((y * 4 + i) * nc + x * 4), acc_rows[i]);
}
delete [] acc_rows;
delete [] a_ptrs;
}
}
#endif
}


@@ -24,17 +24,10 @@ block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int bloc
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len);
// GEMV
void ggml_gemv_q4_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q4_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q8_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q8_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q4_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
// GEMM
void ggml_gemm_q4_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q8_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
#ifdef __cplusplus
}
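
For reference, the ggml.c hunks further down consume these two entry points by routing batched src1 rows through gemm and single or leftover rows through gemv. A simplified, hypothetical caller (ignoring the quantized wdata repacking and thread partitioning done in the real code, and using groups of 4 rows where the real code batches 16) might look like:

#include <stddef.h>

// Function-pointer types matching the gemv/gemm signatures declared above.
typedef void (*gemv_fn)(int n, float * s, const void * vx, const void * vy,
                        int nr, int nc, int ith, int nth);
typedef void (*gemm_fn)(int n, float * s, const void * vx, const void * vy,
                        int nr, int nc, int ith, int nth);

// Hypothetical dispatch sketch: gemm handles groups of four RHS rows, gemv the remainder,
// echoing the ne11-based selection in the ggml.c diff below.
static void example_mul_mat_dispatch(gemv_fn gemv, gemm_fn gemm,
                                     int ne00, int ne01, int ne11,
                                     float * dst, const void * src0, const char * wdata,
                                     size_t row_size, size_t nb1, int ith, int nth) {
    int row = 0;
    for (; row + 4 <= ne11; row += 4) {
        gemm(ne00, (float *)((char *) dst + row * nb1), src0,
             wdata + row * row_size, 4, ne01, ith, nth);
    }
    for (; row < ne11; row++) {
        gemv(ne00, (float *)((char *) dst + row * nb1), src0,
             wdata + row * row_size, 1, ne01, ith, nth);
    }
}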


@@ -38,7 +38,7 @@
#include <unistd.h>
#endif
#ifdef __ARM_FEATURE_MATMUL_INT8
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
#undef GGML_USE_LLAMAFILE
#endif
@@ -915,16 +915,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
#if defined(__ARM_FEATURE_SVE)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_sve256,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_sve256,
#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_neon,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_neon,
#elif defined(__ARM_NEON)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm,
#endif
.gemv = ggml_gemv_q4_0_q8_0_aarch64,
.gemm = ggml_gemm_q4_0_q8_0_aarch64,
}
};
@@ -12242,7 +12234,7 @@ UseGgmlGemm1:;
}
}
}
if ((type == GGML_TYPE_Q4_0_AARCH64) && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
if (from_float_to_mat && gemm && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) {
from_float_to_mat((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10, 4, ggml_cpu_has_matmul_int8() ? 8 : 4);
wdata += row_size * 4;
@@ -12340,10 +12332,9 @@ UseGgmlGemm2:;
//if (ith == 0)
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (type == GGML_TYPE_Q4_0_AARCH64)) {
gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
}
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 2) && (type == GGML_TYPE_Q4_0_AARCH64)) {
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
if (ne11 == 1) gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
else {
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
}
@@ -12357,7 +12348,13 @@ UseGgmlGemm2:;
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
}
else if ((ggml_n_dims(src0) == 2) && gemv) {
for (int row_iter = 0; row_iter < ne11; row_iter++) {
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
else {