Arm AArch64: minor code refactoring for resolving a build issue with cmake
parent 8ee6779147
commit a657246d62
3 changed files with 547 additions and 800 deletions
307 ggml-aarch64.cpp
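
The changes below fold the per-microarchitecture GEMV/GEMM kernels (the _sve256, _neon and _neon_noi8mm variants) behind single ggml_gemv_q4_0_q8_0_aarch64 / ggml_gemm_q4_0_q8_0_aarch64 entry points, so the variant selection happens inside ggml-aarch64.cpp rather than in the type_traits initializer. As a rough, hedged sketch of that dispatch shape (the *_impl helpers are hypothetical stand-ins, not ggml symbols, and the exact checks may differ from the committed code):

// sketch_dispatch.c: illustrative only; the *_impl kernels are hypothetical stubs.
#include <stddef.h>

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif

static void gemv_sve256_impl   (int n, float *s, const void *vx, const void *vy, int nr, int nc, int ith, int nth) { (void)n; (void)s; (void)vx; (void)vy; (void)nr; (void)nc; (void)ith; (void)nth; /* 256-bit SVE kernel would go here */ }
static void gemv_neon_i8mm_impl(int n, float *s, const void *vx, const void *vy, int nr, int nc, int ith, int nth) { (void)n; (void)s; (void)vx; (void)vy; (void)nr; (void)nc; (void)ith; (void)nth; /* NEON + int8 matmul kernel would go here */ }
static void gemv_neon_impl     (int n, float *s, const void *vx, const void *vy, int nr, int nc, int ith, int nth) { (void)n; (void)s; (void)vx; (void)vy; (void)nr; (void)nc; (void)ith; (void)nth; /* plain NEON kernel would go here */ }

// One exported entry point; the micro-kernel is chosen from the compile-time
// features and, for SVE, from the runtime vector length.
void gemv_q4_0_q8_0_dispatch(int n, float *s, const void *vx, const void *vy,
                             int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
    // svcntw() == 8 means eight 32-bit lanes per vector, i.e. a 256-bit SVE implementation.
    if (svcntw() == 8) {
        gemv_sve256_impl(n, s, vx, vy, nr, nc, ith, nth);
        return;
    }
#endif
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    gemv_neon_i8mm_impl(n, s, vx, vy, nr, nc, ith, nth);
#elif defined(__ARM_NEON)
    gemv_neon_impl(n, s, vx, vy, nr, nc, ith, nth);
#else
    (void)n; (void)s; (void)vx; (void)vy; (void)nr; (void)nc; (void)ith; (void)nth;
#endif
}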
@@ -1,4 +1,8 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.

#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#define GGML_COMMON_IMPL_C
#include "ggml-common.h"

@@ -315,12 +319,18 @@ inline int64_t roundup(const int64_t a, const int64_t b) {
}
}

void ggml_gemv_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
void ggml_gemv_q4_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(n);
UNUSED(s);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(ith);
UNUSED(nth);

#if defined(__ARM_FEATURE_SVE)
if (svcntw() != 8) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemv_q4_0_q8_0_aarch64_neon(n, s, vx, vy, nr, nc, ith, nth);
return;
}
if (svcntw() == 8) {
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
size_t width = xend - x0;
@@ -393,12 +403,10 @@ void ggml_gemv_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const vo
: [a_ptr] "r" (a_ptr), [num_blocks] "r" (num_blocks)
: "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
);
#endif
return;
}

void ggml_gemv_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(nr);
#if defined(__ARM_NEON)
#endif
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
size_t width = xend - x0;
@@ -470,12 +478,7 @@ void ggml_gemv_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void
: [a_ptr] "r" (a_ptr), [num_blocks] "r" (num_blocks)
: "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
);
#endif
}

void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(nr);
#if defined(__ARM_NEON)
#elif defined(__ARM_NEON)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
size_t width = xend - x0;
@@ -545,168 +548,18 @@ void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, con
#endif
}

void ggml_gemv_q8_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
void ggml_gemm_q4_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
UNUSED(n);
UNUSED(s);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(ith);
UNUSED(nth);

int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;

const svbool_t ptrue = svptrue_b8();

const block_q8_0x8 * b_ptr_start = (const block_q8_0x8 *) vx;
const block_q8_0 * a_ptr_start = (const block_q8_0 *) vy;

for (int64_t y = 0; y < nr; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);

// Master FP accumulator
svfloat32_t acc_row = svdup_f32(0.0f);

for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const svint8_t rhs_vec_0_0_0 = svld1_s8(ptrue, b_ptr[b].qs);
const svint8_t rhs_vec_0_1_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 1);
const svint8_t rhs_vec_0_2_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 2);
const svint8_t rhs_vec_0_3_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 3);
const svint8_t rhs_vec_0_0_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 4);
const svint8_t rhs_vec_0_1_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 5);
const svint8_t rhs_vec_0_2_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 6);
const svint8_t rhs_vec_0_3_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 7);

// Scale values
const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d));
const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16);

const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d);
const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16);

const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs);
const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16);

svint32_t iacc = svdup_s32(0);

iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0);

iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1);

iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2);

iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3);

acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32));
}

svst1(ptrue, s + (y * nc + x * 8), acc_row);
}
}
#endif
}

void ggml_gemv_q8_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_NEON)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);

int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;

const block_q8_0x8 * b_ptr_start = (const block_q8_0x8 *) vx;
const block_q8_0 * a_ptr_start = (const block_q8_0 *) vy;

for (int64_t y = 0; y < nr; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
float32x4_t acc_row[2];
acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f);

for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const int8x16_t rhs_vec_0_0_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_vec_1_0_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_vec_0_1_0 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_vec_1_1_0 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_vec_0_2_0 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_vec_1_2_0 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_vec_0_3_0 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_vec_1_3_0 = vld1q_s8(b_ptr[b].qs + 112);
const int8x16_t rhs_vec_0_0_1 = vld1q_s8(b_ptr[b].qs + 128);
const int8x16_t rhs_vec_1_0_1 = vld1q_s8(b_ptr[b].qs + 144);
const int8x16_t rhs_vec_0_1_1 = vld1q_s8(b_ptr[b].qs + 160);
const int8x16_t rhs_vec_1_1_1 = vld1q_s8(b_ptr[b].qs + 176);
const int8x16_t rhs_vec_0_2_1 = vld1q_s8(b_ptr[b].qs + 192);
const int8x16_t rhs_vec_1_2_1 = vld1q_s8(b_ptr[b].qs + 208);
const int8x16_t rhs_vec_0_3_1 = vld1q_s8(b_ptr[b].qs + 224);
const int8x16_t rhs_vec_1_3_1 = vld1q_s8(b_ptr[b].qs + 240);

// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x8_t col_scale_f16 = vld1q_f16((const ggml_fp16_internal_t *)(b_ptr[b].d));
const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16));
const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16));

const float16x4_t row_scale_f16 = vld1_dup_f16((const ggml_fp16_internal_t *)(&(a_ptr[b].d)));
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);

const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs);
const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16);

int32x4_t iacc0 = vdupq_n_s32(0);
int32x4_t iacc1 = vdupq_n_s32(0);

iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0);

iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0);

iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1);

iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1);

iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2);

iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);

iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3);

iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3);

acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32));
acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32));
}

vst1q_f32(s + (y * nc + x * 8), acc_row[0]);
vst1q_f32(s + (y * nc + x * 8 + 4), acc_row[1]);
}
}
#endif
}

void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntw() != 8) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemm_q4_0_q8_0_aarch64_neon(n, s, vx, vy, nr, nc, ith, nth);
return;
}
if (svcntw() == 8) {
int64_t x0 = roundup((ith * nc) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)8);
size_t width = xend - x0;
@@ -1124,10 +977,9 @@ void ggml_gemm_q4_0_q8_0_aarch64_sve256(int n, float * GGML_RESTRICT s, const vo
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [num_blocks] "r" (num_blocks), [res_stride] "r" (res_stride), [width] "r" (width)
: "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
);
#endif
return;
}

void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#endif
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
@@ -1534,11 +1386,7 @@ void ggml_gemm_q4_0_q8_0_aarch64_neon(int n, float * GGML_RESTRICT s, const void
: [b_ptr] "r" (b_ptr), [nr] "r" (nr), [num_blocks] "r" (num_blocks), [res_stride] "r" (res_stride), [width] "r" (width)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
);
#endif
}

void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_NEON)
#elif defined(__ARM_NEON)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);
size_t width = xend - x0;
@@ -2006,94 +1854,3 @@ void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, con
);
#endif
}

void ggml_gemm_q8_0_q8_0_aarch64(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * nc) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * nc) / nth, (int64_t)4);

int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;

const block_q8_0x4 * b_ptr_start = (const block_q8_0x4 *) vx;
const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *) vy;

for (int64_t y = 0; y < nr / 4; y += nr / 4) {
for (int64_t x = x0 / 4; x < xend / 4; x++) {
const block_q8_0x4 ** a_ptrs = new const block_q8_0x4 * [nr / 4];

a_ptrs[0] = a_ptr_start + (y * a_nb);
for (int i = 0; i < (nr / 4) - 1; i++) {
a_ptrs[i + 1] = a_ptrs[i] + a_nb;
}

const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb);

// Master FP accumulators
float32x4_t * acc_rows = new float32x4_t[nr];
for (int i = 0; i < nr; i++) {
acc_rows[i] = vdupq_n_f32(0.0f);
}

for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112);

// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x4_t col_scale_f16 = vld1_f16((const ggml_fp16_internal_t *)(b_ptr[b].d));
const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);

// Process LHS in pairs of rows
for (int rp = 0; rp < nr / 4; rp++) {
const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16);
const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32);
const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48);

const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64);
const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80);
const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96);
const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112);

// Do the MMLAs into 2x2 matrices
const int32x4_t iacc_mat_00 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
const int32x4_t iacc_mat_01 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
const int32x4_t iacc_mat_10 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3);
const int32x4_t iacc_mat_11 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3);

// Straighten out to make 4 row vectors
const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));

const float16x4_t row_scale_f16 = vld1_f16((const ggml_fp16_internal_t *)(a_ptrs[rp][b].d));
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);

acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0));
acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1));
acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2));
acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3));
}
}

for (int i = 0; i < nr; i++) {
vst1q_f32(s + ((y * 4 + i) * nc + x * 4), acc_rows[i]);
}
delete [] acc_rows;
delete [] a_ptrs;
}
}
#endif
}

@@ -24,17 +24,10 @@ block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int bloc
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len);

// GEMV
void ggml_gemv_q4_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q4_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q8_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q8_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemv_q4_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);

// GEMM
void ggml_gemm_q4_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q8_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
void ggml_gemm_q4_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);

#ifdef __cplusplus
}

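A pattern that recurs throughout the kernels above is the x0/xend computation: each of the nth threads takes a contiguous slice of the nc output columns, with both slice boundaries rounded up to the kernel's tile width (8 columns for the SVE-256 paths, 4 for the NEON paths). A small self-contained illustration, assuming the usual round-up-to-a-multiple definition of roundup; the sizes in main are invented for the example:

// sketch_partition.c: per-thread column partitioning, aligned to the tile width.
#include <stdint.h>
#include <stdio.h>

static int64_t roundup(const int64_t a, const int64_t b) {
    return ((a + b - 1) / b) * b;   // smallest multiple of b that is >= a
}

int main(void) {
    const int64_t nc = 128, nth = 3, tile = 8;   // example sizes, not taken from the commit
    for (int64_t ith = 0; ith < nth; ith++) {
        const int64_t x0   = roundup((ith * nc) / nth, tile);        // first column for this thread
        const int64_t xend = roundup(((ith + 1) * nc) / nth, tile);  // one past the last column
        printf("thread %lld works on columns [%lld, %lld)\n",
               (long long) ith, (long long) x0, (long long) xend);
    }
    return 0;
}

With nc = 128, nth = 3 and tile = 8 this prints the slices [0, 48), [48, 88) and [88, 128), so every thread's range starts and ends on a tile boundary.
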
@@ -38,7 +38,7 @@
#include <unistd.h>
#endif

#ifdef __ARM_FEATURE_MATMUL_INT8
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
#undef GGML_USE_LLAMAFILE
#endif

@@ -915,16 +915,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
#if defined(__ARM_FEATURE_SVE)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_sve256,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_sve256,
#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_neon,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_neon,
#elif defined(__ARM_NEON)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm,
#endif
.gemv = ggml_gemv_q4_0_q8_0_aarch64,
.gemm = ggml_gemm_q4_0_q8_0_aarch64,
}
};

@@ -12242,7 +12234,7 @@ UseGgmlGemm1:;
}
}
}
if ((type == GGML_TYPE_Q4_0_AARCH64) && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
if (from_float_to_mat && gemm && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) {
from_float_to_mat((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10, 4, ggml_cpu_has_matmul_int8() ? 8 : 4);
wdata += row_size * 4;
@@ -12340,10 +12332,9 @@ UseGgmlGemm2:;
//if (ith == 0)
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);

if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (type == GGML_TYPE_Q4_0_AARCH64)) {
gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
}
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 2) && (type == GGML_TYPE_Q4_0_AARCH64)) {
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
if (ne11 == 1) gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
else {
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
}
@@ -12357,7 +12348,13 @@ UseGgmlGemm2:;
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
}
else if ((ggml_n_dims(src0) == 2) && gemv) {
for (int row_iter = 0; row_iter < ne11; row_iter++) {
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
else {
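
On the ggml.c side, the type_traits hunk above now stores one gemv and one gemm entry point per type, and the mul_mat hunks only take the optimized route when both pointers are present. A minimal sketch of that shape, using illustrative stand-in kernels rather than ggml's real definitions:

// sketch_traits.c: illustrative only; in ggml the table entries point at
// ggml_gemv_q4_0_q8_0_aarch64 / ggml_gemm_q4_0_q8_0_aarch64.
#include <stddef.h>
#include <stdio.h>

typedef void (*gemx_fn)(int n, float *s, const void *vx, const void *vy,
                        int nr, int nc, int ith, int nth);

struct traits_sketch {
    gemx_fn gemv;   // NULL when no optimized kernel is built in
    gemx_fn gemm;
};

static void stub_gemv(int n, float *s, const void *vx, const void *vy, int nr, int nc, int ith, int nth) {
    (void)n; (void)s; (void)vx; (void)vy; (void)nr; (void)nc; (void)ith; (void)nth;
    printf("gemv path\n");
}
static void stub_gemm(int n, float *s, const void *vx, const void *vy, int nr, int nc, int ith, int nth) {
    (void)n; (void)s; (void)vx; (void)vy; (void)nr; (void)nc; (void)ith; (void)nth;
    printf("gemm path\n");
}

// Mirrors the guard in the mul_mat hunks: use the fast path only when both
// kernels are available, gemv for a single source row, gemm otherwise.
static void mul_mat_sketch(const struct traits_sketch *tr, int ne00, int ne01, int ne11,
                           float *dst, const void *src0, const void *src1_wdata) {
    if (tr->gemm && tr->gemv) {
        if (ne11 == 1) tr->gemv(ne00, dst, src0, src1_wdata, 1, ne01, 0, 1);
        else           tr->gemm(ne00, dst, src0, src1_wdata, ne11, ne01, 0, 1);
    } else {
        printf("generic fallback\n");
    }
}

int main(void) {
    const struct traits_sketch tr = { .gemv = stub_gemv, .gemm = stub_gemm };
    float dst[16] = {0};
    mul_mat_sketch(&tr, 32, 4, 1, dst, NULL, NULL);   // single row, takes the gemv branch
    mul_mat_sketch(&tr, 32, 4, 4, dst, NULL, NULL);   // batched rows, takes the gemm branch
    return 0;
}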