Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0 and q8_0_q8_0 quantization

Dibakar Gope 2024-02-28 17:33:41 +00:00 committed by Dibakar Gope
parent 3fd62a6b1c
commit 002e36eaec
6 changed files with 1621 additions and 17 deletions

View file

@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma once
//
@@ -602,6 +603,11 @@ extern "C" {
void * extra; // extra things e.g. for ggml-cuda.cu
// char padding[4];
char padding[9];
void * rearranged_weight_gemv;
void * rearranged_weight_gemm;
bool weight_rearranged;
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -2422,6 +2428,15 @@ extern "C" {
GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
GGML_API void rearrange_q4_0_weights_blocked8_neon(struct ggml_tensor * cur);
GGML_API void rearrange_q4_0_weights_blocked8_sve(struct ggml_tensor * cur);
GGML_API void rearrange_q4_0_weights_for_gemv(struct ggml_tensor * cur);
GGML_API void rearrange_q4_0_weights_for_gemm(struct ggml_tensor * cur);
GGML_API void rearrange_q8_0_weights_blocked8_neon(struct ggml_tensor * cur);
GGML_API void rearrange_q8_0_weights_blocked8_sve(struct ggml_tensor * cur);
GGML_API void rearrange_q8_0_weights_for_gemv(struct ggml_tensor * cur);
GGML_API void rearrange_q8_0_weights_for_gemm(struct ggml_tensor * cur);
#ifdef __cplusplus
}
#endif

View file

@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma once
#include "ggml.h"
@@ -609,6 +610,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE
// precomputed f32 table for f16 (256 KB)
// defined in ggml.c, initialized in ggml_init()
extern float ggml_table_f32_f16[1 << 16];

View file

@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
@@ -14706,6 +14707,929 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
assert(k % QK_K == 0);
block_iq2_s * restrict y = vy;
quantize_row_iq2_s_reference(x, y, k);
// Routines to create the blocked formats
// Note input is array of pointers.
// The exact interleaving format needed is different for GEMM (using SMMLA)
// and GEMV (using SDOT) cases. For GEMM, we interleave 8 pairs of values
// at a time (with the two nibbles separated at runtime to give 2x2x8
// matrices). For GEMV, we need to interleave 4 pairs of values instead.
block_q4_0x4 make_block_q4_0x4(const block_q4_0 * const in[4], unsigned int block_len) {
block_q4_0x4 out;
for (int i = 0; i < 4; i++) {
out.d[i] = in[i]->d;
}
for (int i = 0; i < QK4_0 * 2; i++) {
// We are interleaving 4 rows in blocks of 8, making a total of 32
// output bytes per block (2 MMLA input vectors). This repeats
// until we have processed the whole block.
//
// Per the comment above, for GEMV cases a similar process is used
// but with blocks of 4 instead, giving a single DOT input vector.
//
// In the case of q4, we add on 128 to convert the top nibble from
// "bias offset" form to pure sign form (this saves a subtract when
// we unpack it).
int src_offset = (i / (4 * block_len)) * block_len;
int src_id = (i % (4 * block_len)) / block_len;
src_offset += (i % block_len);
out.qs[i] = in[src_id]->qs[src_offset] + 0x80;
}
return out;
}
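// Worked example of the index mapping above (illustrative, for block_len = 4, the
// GEMV/SDOT layout): with 4 rows interleaved, the loop fills
//   out.qs[ 0.. 3] = in[0]->qs[0..3] + 0x80
//   out.qs[ 4.. 7] = in[1]->qs[0..3] + 0x80
//   out.qs[ 8..11] = in[2]->qs[0..3] + 0x80
//   out.qs[12..15] = in[3]->qs[0..3] + 0x80
//   out.qs[16..19] = in[0]->qs[4..7] + 0x80
//   ... and so on, 4 bytes per row per step, until all QK4_0/2 bytes of every row
// are consumed. With block_len = 8 (the GEMM/SMMLA layout) the same formula copies
// 8 bytes per row at a time instead.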
// 8-block version - see comments in code above
block_q4_0x8 make_block_q4_0x8(const block_q4_0 * const in[8], unsigned int block_len) {
block_q4_0x8 out;
for (int i = 0; i < 8; i++) {
out.d[i] = in[i]->d;
}
for (int i = 0; i < QK4_0 * 4; i++) {
int src_offset = (i / (8 * block_len)) * block_len;
int src_id = (i % (8 * block_len)) / block_len;
src_offset += (i % block_len);
out.qs[i] = in[src_id]->qs[src_offset] + 0x80;
}
return out;
}
block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int block_len) {
block_q8_0x4 out;
for (int i = 0; i < 4; i++) {
out.d[i] = in[i]->d;
}
for (int i = 0; i < QK8_0 * 4; i++) {
int src_offset = (i / (4 * block_len)) * block_len;
int src_id = (i % (4 * block_len)) / block_len;
src_offset += (i % block_len);
out.qs[i] = in[src_id]->qs[src_offset];
}
return out;
}
// 8-block version - see comments in code above
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len) {
block_q8_0x8 out;
for (int i = 0; i < 8; i++) {
out.d[i] = in[i]->d;
}
for (int i = 0; i < QK8_0 * 8; i++) {
int src_offset = (i / (8 * block_len)) * block_len;
int src_id = (i % (8 * block_len)) / block_len;
src_offset += (i % block_len);
out.qs[i] = in[src_id]->qs[src_offset];
}
return out;
}
void quantize_row_q8_0_and_make_block_q8_0x2(const float * restrict x, void * restrict vy, int k, int rows_interleaved) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
block_q8_0x2 * restrict y = vy;
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
float32x4_t srcv[rows_interleaved][8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
float id[rows_interleaved];
for (int row_iter = 0; row_iter < rows_interleaved; row_iter++) {
for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
id[row_iter] = d ? 1.0f / d : 0.0f;
y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
}
for (int j = 0; j < 4; j++) {
float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
int32x4_t vi = vcvtnq_s32_f32(v);
y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
vi = vcvtnq_s32_f32(v);
y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[1][2 * j], id[1]);
vi = vcvtnq_s32_f32(v);
y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
vi = vcvtnq_s32_f32(v);
y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
}
}
#endif
}
void quantize_row_q8_0_and_make_block_q8_0x4(const float * restrict x, void * restrict vy, int k, int rows_interleaved) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;
block_q8_0x4 * restrict y = vy;
#if defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
float32x4_t srcv[rows_interleaved][8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
float id[rows_interleaved];
for (int row_iter = 0; row_iter < rows_interleaved; row_iter++) {
for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
const float amax = vmaxvq_f32(amaxv[0]);
const float d = amax / ((1 << 7) - 1);
id[row_iter] = d ? 1.0f / d : 0.0f;
y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
}
for (int j = 0; j < 4; j++) {
float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
int32x4_t vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[1][2 * j], id[1]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[2][2 * j], id[2]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[3][2 * j], id[3]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
vi = vcvtnq_s32_f32(v);
y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
}
}
#endif
}
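// Reference-only scalar sketch of the routine above, assuming rows_interleaved == 4.
// It is not used by the vectorized path; it just spells out the output layout: every
// 32-byte group of y[i].qs holds 8 quants from row 0, then 8 from rows 1, 2 and 3.
// Note that roundf() stands in for the round-to-nearest-even done by vcvtnq_s32_f32,
// so exact .5 ties may round differently.
static void quantize_row_q8_0_x4_scalar_sketch(const float * restrict x, block_q8_0x4 * restrict y, int k) {
    const int nb = k / QK8_0;
    for (int i = 0; i < nb; i++) {
        float id[4];
        for (int r = 0; r < 4; r++) {
            float amax = 0.0f; // absolute max of the 32 values of row r in this block
            for (int j = 0; j < QK8_0; j++) {
                const float a = fabsf(x[r * k + i * QK8_0 + j]);
                if (a > amax) amax = a;
            }
            const float d = amax / 127.0f;
            id[r] = d ? 1.0f / d : 0.0f;
            y[i].d[r] = GGML_FP32_TO_FP16(d);
        }
        for (int j = 0; j < 4; j++) {        // four 32-byte output groups per block
            for (int r = 0; r < 4; r++) {    // 8 consecutive quants of row r per group
                for (int t = 0; t < 8; t++) {
                    y[i].qs[32 * j + 8 * r + t] = (int8_t) roundf(x[r * k + i * QK8_0 + 8 * j + t] * id[r]);
                }
            }
        }
    }
}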
inline int64_t roundup(const int64_t a, const int64_t b) {
int64_t rem = a % b;
if (rem) {
return a + b - rem;
} else {
return a;
}
}
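// Worked example of the thread split used by the kernels below (numbers are
// illustrative): with output_channels = 4096 and nth = 3 threads, rounding each
// boundary up to a multiple of 8 gives
//   ith = 0:  x0 = roundup(   0, 8) =    0,  xend = roundup(1365, 8) = 1368
//   ith = 1:  x0 = roundup(1365, 8) = 1368,  xend = roundup(2730, 8) = 2736
//   ith = 2:  x0 = roundup(2730, 8) = 2736,  xend = roundup(4096, 8) = 4096
// so every thread works on whole groups of 8 (or 4 for the GEMM kernels) output
// channels and the ranges tile the full output without overlap.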
void ggml_gemv_q4_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_NEON)
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
int64_t nb = n / QK4_0;
int64_t a_nb = n / QK8_0;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const block_q4_0x8 * b_ptr_start = vx;
const block_q8_0 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q4_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
float32x4_t acc_row[2];
acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - load the rhs_vec_* vectors and the column scales
const uint8x16_t rhs_raw_vec_0_0 = vld1q_u8(b_ptr[b].qs);
const uint8x16_t rhs_raw_vec_1_0 = vld1q_u8(b_ptr[b].qs + 16);
const uint8x16_t rhs_raw_vec_0_1 = vld1q_u8(b_ptr[b].qs + 32);
const uint8x16_t rhs_raw_vec_1_1 = vld1q_u8(b_ptr[b].qs + 48);
const uint8x16_t rhs_raw_vec_0_2 = vld1q_u8(b_ptr[b].qs + 64);
const uint8x16_t rhs_raw_vec_1_2 = vld1q_u8(b_ptr[b].qs + 80);
const uint8x16_t rhs_raw_vec_0_3 = vld1q_u8(b_ptr[b].qs + 96);
const uint8x16_t rhs_raw_vec_1_3 = vld1q_u8(b_ptr[b].qs + 112);
const int8x16_t rhs_vec_0_0_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_0, m4b)), s8b);
const int8x16_t rhs_vec_0_1_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_1, m4b)), s8b);
const int8x16_t rhs_vec_0_2_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_2, m4b)), s8b);
const int8x16_t rhs_vec_0_3_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_3, m4b)), s8b);
const int8x16_t rhs_vec_1_0_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_0, m4b)), s8b);
const int8x16_t rhs_vec_1_1_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_1, m4b)), s8b);
const int8x16_t rhs_vec_1_2_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_2, m4b)), s8b);
const int8x16_t rhs_vec_1_3_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_3, m4b)), s8b);
const int8x16_t rhs_vec_0_0_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_0), 4);
const int8x16_t rhs_vec_0_1_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_1), 4);
const int8x16_t rhs_vec_0_2_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_2), 4);
const int8x16_t rhs_vec_0_3_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_3), 4);
const int8x16_t rhs_vec_1_0_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_0), 4);
const int8x16_t rhs_vec_1_1_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_1), 4);
const int8x16_t rhs_vec_1_2_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_2), 4);
const int8x16_t rhs_vec_1_3_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_3), 4);
// Scale values - load the eight column scales into one 128-bit vector and broadcast the row scale, then expand to FP32
const float16x8_t col_scale_f16 = vld1q_f16(b_ptr[b].d);
const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16));
const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16));
const float16x4_t row_scale_f16 = vld1_dup_f16(&(a_ptr[b].d));
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);
const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs);
const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16);
int32x4_t iacc0 = vdupq_n_s32(0);
int32x4_t iacc1 = vdupq_n_s32(0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3);
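// Note on the lane-indexed dot products above: vdotq_laneq_s32(acc, r, l, lane) adds,
// to each 32-bit lane i of acc, the dot product of bytes 4*i..4*i+3 of r with the
// lane-th 4-byte group of l. Given the blocked-8 weight layout, after these 16 calls
// iacc0 holds the full 32-element dot products for output channels x*8 + 0..3 and
// iacc1 the ones for channels x*8 + 4..7.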
acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32));
acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32));
}
vst1q_f32(s + (y * output_channels + x * 8), acc_row[0]);
vst1q_f32(s + (y * output_channels + x * 8 + 4), acc_row[1]);
}
}
#endif
}
void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
int64_t nb = n / QK4_0;
int64_t a_nb = n / QK8_0;
const svuint8_t m4b = svdup_u8(0x0F);
const svint8_t s8b = svdup_s8(0x8);
const svbool_t ptrue = svptrue_b8();
const block_q4_0x8 * b_ptr_start = vx;
const block_q8_0 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q4_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
svfloat32_t acc_row = svdup_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - load the rhs_vec_* vectors and the column scales
const svuint8_t rhs_raw_vec_0_0 = svld1_u8(ptrue, b_ptr[b].qs);
const svuint8_t rhs_raw_vec_0_1 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 1);
const svuint8_t rhs_raw_vec_0_2 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 2);
const svuint8_t rhs_raw_vec_0_3 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 3);
const svint8_t rhs_vec_0_0_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_0), 4);
const svint8_t rhs_vec_0_1_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_1), 4);
const svint8_t rhs_vec_0_2_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_2), 4);
const svint8_t rhs_vec_0_3_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_3), 4);
const svint8_t rhs_vec_0_0_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_0, m4b)), s8b);
const svint8_t rhs_vec_0_1_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_1, m4b)), s8b);
const svint8_t rhs_vec_0_2_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_2, m4b)), s8b);
const svint8_t rhs_vec_0_3_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_3, m4b)), s8b);
// Scale values
const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d));
const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16);
const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d);
const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16);
const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs);
const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16);
svint32_t iacc = svdup_s32(0);
iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0);
iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1);
iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2);
iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3);
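// Note: this kernel assumes a 256-bit SVE vector length (svcntw() == 8, enforced when
// the weights are rearranged), so each svdot_lane call accumulates one 4-byte group of
// activations, replicated per 128-bit segment by svld1rq, against all 8 interleaved
// output channels at once; after the 8 calls iacc holds the per-channel dot products
// for this block.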
acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32));
}
svst1(ptrue, s + (y * output_channels + x * 8), acc_row);
}
}
#endif
}
void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);
int64_t nb = n / QK4_0;
int64_t a_nb = n / QK8_0;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const block_q4_0x4 * b_ptr_start = vx;
const block_q8_0x4 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width / 4; y += rows / 4) {
for (int64_t x = x0 / 4; x < xend / 4; x++) {
const block_q8_0x4 * a_ptrs[rows / 4];
a_ptrs[0] = a_ptr_start + (y * a_nb);
for (int i = 0; i < (rows / 4) - 1; i++) {
a_ptrs[i + 1] = a_ptrs[i] + a_nb;
}
const block_q4_0x4 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulators
float32x4_t acc_rows[rows];
for (int i = 0; i < rows; i++) {
acc_rows[i] = vdupq_n_f32(0.0f);
}
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const uint8x16_t rhs_raw_mat_01_0 = vld1q_u8(b_ptr[b].qs);
const uint8x16_t rhs_raw_mat_23_0 = vld1q_u8(b_ptr[b].qs + 16);
const uint8x16_t rhs_raw_mat_01_1 = vld1q_u8(b_ptr[b].qs + 32);
const uint8x16_t rhs_raw_mat_23_1 = vld1q_u8(b_ptr[b].qs + 48);
// 4-bit -> 8-bit
const int8x16_t rhs_mat_01_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_0, m4b)), s8b);
const int8x16_t rhs_mat_23_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_0, m4b)), s8b);
const int8x16_t rhs_mat_01_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_1, m4b)), s8b);
const int8x16_t rhs_mat_23_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_1, m4b)), s8b);
const int8x16_t rhs_mat_01_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_0), 4);
const int8x16_t rhs_mat_23_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_0), 4);
const int8x16_t rhs_mat_01_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_1), 4);
const int8x16_t rhs_mat_23_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_1), 4);
// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);
// Process LHS in pairs of rows
for (int rp = 0; rp < rows / 4; rp++) {
const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16);
const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32);
const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48);
const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64);
const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80);
const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96);
const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112);
// Do the MMLAs into 2x2 matrices
const int32x4_t iacc_mat_00 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
const int32x4_t iacc_mat_01 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
const int32x4_t iacc_mat_10 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3);
const int32x4_t iacc_mat_11 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3);
// Straighten out to make 4 row vectors
const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
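// Note on the SMMLA step above: vmmlaq_s32(acc, a, b) treats a and b as row-major 2x8
// int8 matrices and accumulates a * b^T into acc as a row-major 2x2 int32 tile, so each
// iacc_mat_rc covers 2 LHS rows against 2 RHS columns over the whole 32-element block.
// The vtrn1q/vtrn2q on 64-bit lanes then regroup the four tiles so that iacc_row_r is
// LHS row r against RHS columns x*4 + 0..3, ready for the per-row scaling below.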
const float16x4_t row_scale_f16 = vld1_f16(a_ptrs[rp][b].d);
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);
acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0));
acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1));
acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2));
acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3));
}
}
for (int i = 0; i < rows; i++) {
vst1q_f32(s + ((y * 4 + i) * output_channels + x * 4), acc_rows[i]);
}
}
}
#endif
}
void ggml_gemm_q4_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
int rows = 2;
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);
int64_t nb = n / QK4_0;
int64_t a_nb = n / QK8_0;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const block_q4_0x4 * b_ptr_start = vx;
const block_q8_0x2 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width / 2; y += rows / 2) {
for (int64_t x = x0 / 4; x < xend / 4; x++) {
const block_q8_0x2 * a_ptrs[rows / 2];
a_ptrs[0] = a_ptr_start + (y * a_nb);
const block_q4_0x4 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulators
float32x4_t acc_rows[rows];
acc_rows[0] = vdupq_n_f32(0.0f);
acc_rows[1] = vdupq_n_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const uint8x16_t rhs_raw_mat_01_0 = vld1q_u8(b_ptr[b].qs);
const uint8x16_t rhs_raw_mat_23_0 = vld1q_u8(b_ptr[b].qs + 16);
const uint8x16_t rhs_raw_mat_01_1 = vld1q_u8(b_ptr[b].qs + 32);
const uint8x16_t rhs_raw_mat_23_1 = vld1q_u8(b_ptr[b].qs + 48);
const int8x16_t rhs_mat_01_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_0, m4b)), s8b);
const int8x16_t rhs_mat_23_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_0, m4b)), s8b);
const int8x16_t rhs_mat_01_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_1, m4b)), s8b);
const int8x16_t rhs_mat_23_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_1, m4b)), s8b);
const int8x16_t rhs_mat_01_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_0), 4);
const int8x16_t rhs_mat_23_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_0), 4);
const int8x16_t rhs_mat_01_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_1), 4);
const int8x16_t rhs_mat_23_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_1), 4);
// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);
// Process LHS in pairs of rows
int rp = 0;
const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 16);
const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 32);
const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 48);
// Do the MMLAs into 2x2 matrices
const int32x4_t iacc_mat_00 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
const int32x4_t iacc_mat_01 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
// Straighten out to make 2 row vectors
const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const float16x4_t row_scale_f16_0 = vld1_dup_f16(&(a_ptrs[rp][b].d[0]));
const float32x4_t row_scale_f32_0 = vcvt_f32_f16(row_scale_f16_0);
const float16x4_t row_scale_f16_1 = vld1_dup_f16(&(a_ptrs[rp][b].d[1]));
const float32x4_t row_scale_f32_1 = vcvt_f32_f16(row_scale_f16_1);
acc_rows[rp * 2] = vfmaq_f32(acc_rows[rp * 2], vcvtq_f32_s32(iacc_row_0), vmulq_f32(col_scale_f32, row_scale_f32_0));
acc_rows[rp * 2 + 1] = vfmaq_f32(acc_rows[rp * 2 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_f32(col_scale_f32, row_scale_f32_1));
}
vst1q_f32(s + ((y * 2) * output_channels + x * 4), acc_rows[0]);
vst1q_f32(s + ((y * 2 + 1) * output_channels + x * 4), acc_rows[1]);
}
}
#endif
}
void ggml_gemv_q8_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_NEON)
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const block_q8_0x8 * b_ptr_start = vx;
const block_q8_0 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
float32x4_t acc_row[2];
acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - load the rhs_vec_* vectors and the column scales
const int8x16_t rhs_vec_0_0_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_vec_1_0_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_vec_0_1_0 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_vec_1_1_0 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_vec_0_2_0 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_vec_1_2_0 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_vec_0_3_0 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_vec_1_3_0 = vld1q_s8(b_ptr[b].qs + 112);
const int8x16_t rhs_vec_0_0_1 = vld1q_s8(b_ptr[b].qs + 128);
const int8x16_t rhs_vec_1_0_1 = vld1q_s8(b_ptr[b].qs + 144);
const int8x16_t rhs_vec_0_1_1 = vld1q_s8(b_ptr[b].qs + 160);
const int8x16_t rhs_vec_1_1_1 = vld1q_s8(b_ptr[b].qs + 176);
const int8x16_t rhs_vec_0_2_1 = vld1q_s8(b_ptr[b].qs + 192);
const int8x16_t rhs_vec_1_2_1 = vld1q_s8(b_ptr[b].qs + 208);
const int8x16_t rhs_vec_0_3_1 = vld1q_s8(b_ptr[b].qs + 224);
const int8x16_t rhs_vec_1_3_1 = vld1q_s8(b_ptr[b].qs + 240);
// Scale values - load the eight column scales into one 128-bit vector and broadcast the row scale, then expand to FP32
const float16x8_t col_scale_f16 = vld1q_f16(b_ptr[b].d);
const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16));
const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16));
const float16x4_t row_scale_f16 = vld1_dup_f16(&(a_ptr[b].d));
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);
const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs);
const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16);
int32x4_t iacc0 = vdupq_n_s32(0);
int32x4_t iacc1 = vdupq_n_s32(0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3);
acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32));
acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32));
}
vst1q_f32(s + (y * output_channels + x * 8), acc_row[0]);
vst1q_f32(s + (y * output_channels + x * 8 + 4), acc_row[1]);
}
}
#endif
}
void ggml_gemv_q8_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const svbool_t ptrue = svptrue_b8();
const block_q8_0x8 * b_ptr_start = vx;
const block_q8_0 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width; y++) {
for (int64_t x = x0 / 8; x < xend / 8; x++) {
// Pointers to LHS blocks
const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
// Pointers to RHS blocks
const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulator
svfloat32_t acc_row = svdup_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - load the rhs_vec_* vectors and the column scales
const svint8_t rhs_vec_0_0_0 = svld1_s8(ptrue, b_ptr[b].qs);
const svint8_t rhs_vec_0_1_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 1);
const svint8_t rhs_vec_0_2_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 2);
const svint8_t rhs_vec_0_3_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 3);
const svint8_t rhs_vec_0_0_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 4);
const svint8_t rhs_vec_0_1_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 5);
const svint8_t rhs_vec_0_2_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 6);
const svint8_t rhs_vec_0_3_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 7);
// Scale values
const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d));
const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16);
const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d);
const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16);
const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs);
const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16);
svint32_t iacc = svdup_s32(0);
iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0);
iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0);
iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1);
iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1);
iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2);
iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2);
iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3);
iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3);
acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32));
}
svst1(ptrue, s + (y * output_channels + x * 8), acc_row);
}
}
#endif
}
void ggml_gemm_q8_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const block_q8_0x4 * b_ptr_start = vx;
const block_q8_0x4 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width / 4; y += rows / 4) {
for (int64_t x = x0 / 4; x < xend / 4; x++) {
const block_q8_0x4 * a_ptrs[rows / 4];
a_ptrs[0] = a_ptr_start + (y * a_nb);
for (int i = 0; i < (rows / 4) - 1; i++) {
a_ptrs[i + 1] = a_ptrs[i] + a_nb;
}
const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulators
float32x4_t acc_rows[rows];
for (int i = 0; i < rows; i++) {
acc_rows[i] = vdupq_n_f32(0.0f);
}
for (int64_t b = 0; b < nb; b++) {
// Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112);
// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);
// Process LHS in pairs of rows
for (int rp = 0; rp < rows / 4; rp++) {
const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16);
const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32);
const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48);
const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64);
const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80);
const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96);
const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112);
// Do the MMLAs into 2x2 matrices
const int32x4_t iacc_mat_00 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
const int32x4_t iacc_mat_01 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
const int32x4_t iacc_mat_10 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3);
const int32x4_t iacc_mat_11 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3);
// Straighten out to make 4 row vectors
const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
const float16x4_t row_scale_f16 = vld1_f16(a_ptrs[rp][b].d);
const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);
acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0));
acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1));
acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2));
acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3));
}
}
for (int i = 0; i < rows; i++) {
vst1q_f32(s + ((y * 4 + i) * output_channels + x * 4), acc_rows[i]);
}
}
}
#endif
}
void ggml_gemm_q8_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
int rows = 2;
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);
int64_t nb = n / QK8_0;
int64_t a_nb = n / QK8_0;
const block_q8_0x4 * b_ptr_start = vx;
const block_q8_0x2 * a_ptr_start = vy;
for (int64_t y = 0; y < input_width / 2; y += rows / 2) {
for (int64_t x = x0 / 4; x < xend / 4; x++) {
const block_q8_0x2 * a_ptrs[rows / 2];
a_ptrs[0] = a_ptr_start + (y * a_nb);
const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb);
// Master FP accumulators
float32x4_t acc_rows[rows];
acc_rows[0] = vdupq_n_f32(0.0f);
acc_rows[1] = vdupq_n_f32(0.0f);
for (int64_t b = 0; b < nb; b++) {
const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs);
const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16);
const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32);
const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48);
const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64);
const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80);
const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96);
const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112);
// Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);
// Process LHS in pairs of rows
int rp = 0;
const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 16);
const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 32);
const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 48);
// Do the MMLAs into 2x2 matrices
const int32x4_t iacc_mat_00 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
const int32x4_t iacc_mat_01 =
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
// Straighten out to make 2 row vectors
const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
const float16x4_t row_scale_f16_0 = vld1_dup_f16(&(a_ptrs[rp][b].d[0]));
const float32x4_t row_scale_f32_0 = vcvt_f32_f16(row_scale_f16_0);
const float16x4_t row_scale_f16_1 = vld1_dup_f16(&(a_ptrs[rp][b].d[1]));
const float32x4_t row_scale_f32_1 = vcvt_f32_f16(row_scale_f16_1);
acc_rows[rp * 2] = vfmaq_f32(acc_rows[rp * 2], vcvtq_f32_s32(iacc_row_0), vmulq_f32(col_scale_f32, row_scale_f32_0));
acc_rows[rp * 2 + 1] = vfmaq_f32(acc_rows[rp * 2 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_f32(col_scale_f32, row_scale_f32_1));
}
vst1q_f32(s + ((y * 2) * output_channels + x * 4), acc_rows[0]);
vst1q_f32(s + ((y * 2 + 1) * output_channels + x * 4), acc_rows[1]);
}
}
#endif
}
static bool validate_float(float f, size_t i) {

View file

@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma once
#define GGML_COMMON_DECL_C
@@ -7,6 +8,250 @@
// GGML internal header
#include <stdint.h>
#include <stddef.h>
#define QK4_0 32
typedef struct {
ggml_fp16_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
typedef struct {
ggml_fp16_t d; // delta
ggml_fp16_t m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK5_0 32
typedef struct {
ggml_fp16_t d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
typedef struct {
ggml_fp16_t d; // delta
ggml_fp16_t m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
typedef struct {
ggml_fp16_t d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
#define QK8_1 32
typedef struct {
float d; // delta
float s; // d * sum(qs[i])
int8_t qs[QK8_1]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
typedef struct {
ggml_fp16_t d[4]; // deltas for 4 q4_0 blocks
uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
} block_q4_0x4;
static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_fp16_t) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
typedef struct {
ggml_fp16_t d[8]; // deltas for 8 q4_0 blocks
uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
} block_q4_0x8;
static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_fp16_t) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
typedef struct {
ggml_fp16_t d[16]; // deltas for 16 q4_0 blocks
uint8_t qs[QK4_0 * 8]; // nibbles / quants for 16 q4_0 blocks
} block_q4_0x16;
static_assert(sizeof(block_q4_0x16) == 16 * sizeof(ggml_fp16_t) + QK4_0 * 8, "wrong q4_0x16 block size/padding");
typedef struct {
ggml_fp16_t d[64]; // deltas for 64 q4_0 blocks
uint8_t qs[QK4_0 * 32];// nibbles / quants for 64 q4_0 blocks
} block_q4_0x64;
static_assert(sizeof(block_q4_0x64) == 64 * sizeof(ggml_fp16_t) + QK4_0 * 32, "wrong q4_0x64 block size/padding");
typedef struct {
ggml_fp16_t d[2]; // deltas for 2 q8_0 blocks
int8_t qs[QK8_0 * 2]; // quants for 2 q8_0 blocks
} block_q8_0x2;
static_assert(sizeof(block_q8_0x2) == 2 * sizeof(ggml_fp16_t) + QK8_0 * 2, "wrong q8_0x2 block size/padding");
typedef struct {
ggml_fp16_t d[4]; // deltas for 4 q8_0 blocks
int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks
} block_q8_0x4;
static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_fp16_t) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
typedef struct {
ggml_fp16_t d[8]; // deltas for 8 q8_0 blocks
int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks
} block_q8_0x8;
static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_fp16_t) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
//
// Super-block quantization structures
//
// Super-block size
#ifdef GGML_QKK_64
#define QK_K 64
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define K_SCALE_SIZE 12
#endif
// 2-bit quantization
// weight is represented as x = a * q + b
// 16 blocks of 16 elements each
// Effectively 2.625 bits per weight
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
// 3-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 3.4375 bits per weight
#ifdef GGML_QKK_64
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[2];
ggml_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
#else
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[12]; // scales, quantized with 6 bits
ggml_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
#endif
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d[2]; // super-block scales/mins
uint8_t scales[2]; // 4-bit block scales/mins
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
#endif
// 5-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d; // super-block scale
int8_t scales[QK_K/16]; // 8-bit block scales
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
ggml_fp16_t d; // super-block scale
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
// This is only used for intermediate quantization and dot products
typedef struct {
float d; // delta
int8_t qs[QK_K]; // quants
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
// (Almost) "true" 2-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 2.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
ggml_fp16_t d;
uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
// 2.3125 bpw quants
typedef struct {
ggml_fp16_t d;
uint16_t qs[QK_K/8];
uint8_t scales[QK_K/32];
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
// (Almost) "true" 3-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 3.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
ggml_fp16_t d;
uint8_t qs[3*QK_K/8];
} block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
typedef struct {
ggml_fp16_t d;
uint8_t qs[QK_K/8];
uint8_t scales[QK_K/16];
} block_iq1_s;
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
// Non-linear quants
#define QK4_NL 32
typedef struct {
ggml_fp16_t d;
uint8_t qs[QK4_NL/2];
} block_iq4_nl;
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
#ifdef __cplusplus
extern "C" {
#endif
@@ -127,6 +372,25 @@ void iq2xs_free_impl(enum ggml_type type);
void iq3xs_init_impl(int grid_size);
void iq3xs_free_impl(int grid_size);
block_q4_0x4 make_block_q4_0x4(const block_q4_0 * const in[4], unsigned int block_len);
block_q4_0x8 make_block_q4_0x8(const block_q4_0 * const in[8], unsigned int block_len);
block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int block_len);
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len);
void quantize_row_q8_0_and_make_block_q8_0x2(const float * restrict x, void * restrict vy, int k, int rows_interleaved);
void quantize_row_q8_0_and_make_block_q8_0x4(const float * restrict x, void * restrict vy, int k, int rows_interleaved);
// GEMV
void ggml_gemv_q4_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemv_q8_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemv_q8_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
// GEMM
void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemm_q4_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemm_q8_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemm_q8_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
#ifdef __cplusplus
}
#endif

View file

@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC
@@ -473,6 +474,204 @@ int64_t ggml_cycles_per_ms(void) {
return CLOCKS_PER_SEC/1000;
}
#ifdef GGML_PERF
#define ggml_perf_time_ms() ggml_time_ms()
#define ggml_perf_time_us() ggml_time_us()
#define ggml_perf_cycles() ggml_cycles()
#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
#else
#define ggml_perf_time_ms() 0
#define ggml_perf_time_us() 0
#define ggml_perf_cycles() 0
#define ggml_perf_cycles_per_ms() 0
#endif
void rearrange_q4_0_weights_blocked8_neon(struct ggml_tensor * cur) {
block_q4_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
block_q4_0x8 * out_ptr_B_start = out_ptr_B;
int64_t nb = cur->ne[0] / QK4_0;
for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
const block_q4_0 * in_ptrs[8];
in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 8 * nb);
for (int i = 0; i < 7; i++) {
in_ptrs[i + 1] = in_ptrs[i] + nb;
}
for (int64_t x = 0; x < nb; x++) {
*out_ptr_B = make_block_q4_0x8(in_ptrs, 4); // block_len=4 for SDOT
out_ptr_B++;
for (int i = 0; i < 8; i++) {
in_ptrs[i]++;
}
}
}
cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
}
void rearrange_q4_0_weights_blocked8_sve(struct ggml_tensor * cur) {
#if defined(__ARM_FEATURE_SVE)
if (svcntw() != 8) {
printf("ggml_gemv_q4_0_q8_0_blocked8_sve: SVE VL != 256 - aborting. Use Arm Neon GEMV kernels\n");
exit(1);
}
block_q4_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
block_q4_0x8 * out_ptr_B_start = out_ptr_B;
int64_t nb = cur->ne[0] / QK4_0;
for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
const block_q4_0 * in_ptrs[8];
in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 8 * nb);
for (int i = 0; i < 7; i++) {
in_ptrs[i + 1] = in_ptrs[i] + nb;
}
for (int64_t x = 0; x < nb; x++) {
*out_ptr_B = make_block_q4_0x8(in_ptrs, 4); // block_len=4 for SDOT
out_ptr_B++;
for (int i = 0; i < 8; i++) {
in_ptrs[i]++;
}
}
}
cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
#endif
}
#if defined(__ARM_FEATURE_SVE)
static void (*_rearrange_q4_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q4_0_weights_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_rearrange_q4_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q4_0_weights_blocked8_neon;
#endif
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
void rearrange_q4_0_weights_for_gemv(struct ggml_tensor * cur) { _rearrange_q4_0_weights_for_gemv(cur); }
#endif
void rearrange_q4_0_weights_for_gemm(struct ggml_tensor * cur) {
block_q4_0x4 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
block_q4_0x4 * out_ptr_B_start = out_ptr_B;
int64_t nb = cur->ne[0] / QK4_0;
for (int y_out = 0; y_out < cur->ne[1] / 4; y_out++) {
const block_q4_0 * in_ptrs[4];
in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 4 * nb);
for (int i = 0; i < 3; i++) {
in_ptrs[i + 1] = in_ptrs[i] + nb;
}
for (int64_t x = 0; x < nb; x++) {
*out_ptr_B =
make_block_q4_0x4(in_ptrs, 8); // block_len=8 for SMMLA
out_ptr_B++;
for (int i = 0; i < 4; i++) {
in_ptrs[i]++;
}
}
}
cur->rearranged_weight_gemm = (uint8_t *) out_ptr_B_start;
}
void rearrange_q8_0_weights_blocked8_neon(struct ggml_tensor * cur) {
block_q8_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
block_q8_0x8 * out_ptr_B_start = out_ptr_B;
int64_t nb = cur->ne[0] / QK8_0;
for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
const block_q8_0 * in_ptrs[8];
in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 8 * nb);
for (int i = 0; i < 7; i++) {
in_ptrs[i + 1] = in_ptrs[i] + nb;
}
for (int64_t x = 0; x < nb; x++) {
*out_ptr_B = make_block_q8_0x8(in_ptrs, 4); // block_len=4 for SDOT
out_ptr_B++;
for (int i = 0; i < 8; i++) {
in_ptrs[i]++;
}
}
}
cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
}
void rearrange_q8_0_weights_blocked8_sve(struct ggml_tensor * cur) {
#if defined(__ARM_FEATURE_SVE)
if (svcntw() != 8) {
printf("ggml_gemv_q8_0_q8_0_blocked8_sve: SVE VL != 256 - aborting. Use Arm Neon GEMV kernels\n");
exit(1);
}
block_q8_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
block_q8_0x8 * out_ptr_B_start = out_ptr_B;
int64_t nb = cur->ne[0] / QK8_0;
for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
const block_q8_0 * in_ptrs[8];
in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 8 * nb);
for (int i = 0; i < 7; i++) {
in_ptrs[i + 1] = in_ptrs[i] + nb;
}
for (int64_t x = 0; x < nb; x++) {
*out_ptr_B = make_block_q8_0x8(in_ptrs, 4); // block_len=4 for SDOT
out_ptr_B++;
for (int i = 0; i < 8; i++) {
in_ptrs[i]++;
}
}
}
cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
#endif
}
#if defined(__ARM_FEATURE_SVE)
static void (*_rearrange_q8_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q8_0_weights_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_rearrange_q8_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q8_0_weights_blocked8_neon;
#endif
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
void rearrange_q8_0_weights_for_gemv(struct ggml_tensor * cur) { _rearrange_q8_0_weights_for_gemv(cur); }
#endif
void rearrange_q8_0_weights_for_gemm(struct ggml_tensor * cur) {
block_q8_0x4 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
block_q8_0x4 * out_ptr_B_start = out_ptr_B;
int64_t nb = cur->ne[0] / QK8_0;
for (int y_out = 0; y_out < cur->ne[1] / 4; y_out++) {
const block_q8_0 * in_ptrs[4];
in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 4 * nb);
for (int i = 0; i < 3; i++) {
in_ptrs[i + 1] = in_ptrs[i] + nb;
}
for (int64_t x = 0; x < nb; x++) {
*out_ptr_B =
make_block_q8_0x4(in_ptrs, 8); // block_len=8 for SMMLA
out_ptr_B++;
for (int i = 0; i < 4; i++) {
in_ptrs[i]++;
}
}
}
cur->rearranged_weight_gemm = (uint8_t *) out_ptr_B_start;
}
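/*
 * Precondition sketch; example_can_rearrange() is hypothetical. The blocked8
 * GEMV rearrangers walk ne[1] / 8 row groups and the x4 GEMM rearrangers walk
 * ne[1] / 4, each with ne[0] / QK4_0 (or QK8_0) blocks per row, so both
 * dimensions are assumed to divide evenly; trailing rows or blocks would
 * otherwise be silently dropped from the rearranged copy.
 */
static inline bool example_can_rearrange(const struct ggml_tensor * t, int64_t qk, int64_t rows_per_group) {
    return (t->ne[0] % qk == 0) && (t->ne[1] % rows_per_group == 0);
}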
//
// cross-platform UTF-8 file paths
//
@ -2605,6 +2804,10 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
*s = idx;
}
static void ggml_gemv_q4_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
static void ggml_gemv_q8_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
//
// data types
//
@ -3647,6 +3850,9 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.name =*/ { 0 },
/*.extra =*/ NULL,
///*.padding =*/ { 0 },
/*.rearranged_weight_gemv =*/ NULL,
/*.rearranged_weight_gemm =*/ NULL,
/*.weight_rearranged =*/ false,
};
#ifdef __clang__
@ -12199,7 +12405,32 @@ UseGgmlGemm1:;
}
}
}
}
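/*
 * Fast path for quantizing src1 when int8 matmul is available: for a single
 * 2-D batch with at least 4 rows, src1 rows are quantized four at a time and
 * packed into the blocked q8_0x4 layout used by the batched GEMM path, and any
 * remainder rows fall back to the generic per-row conversion below. For
 * example, ne11 == 7 gives one 4-row group (rows 0-3) followed by three
 * per-row conversions (rows 4-6).
 */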
#if defined(__ARM_FEATURE_MATMUL_INT8)
if ((src0->weight_rearranged == true) && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) {
quantize_row_q8_0_and_make_block_q8_0x4((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10, 4);
wdata += row_size * 4;
}
for (int64_t i11 = (ne11 / 4) * 4; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i11 * nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
else {
#endif
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
#if defined(__ARM_FEATURE_MATMUL_INT8)
}
#endif
if (ith == 0) {
// Every thread starts at chunk ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
@ -12275,25 +12506,141 @@ UseGgmlGemm2:;
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
//if (ith == 0)
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) {
break;
#if defined(__ARM_FEATURE_MATMUL_INT8) && (defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE))
if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (src0->weight_rearranged == true)) {
if (src0->type == GGML_TYPE_Q4_0) {
ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data), (const char *) src0->rearranged_weight_gemv, (const char *) wdata, ith, nth); // use Arm Neon/SVE GEMV kernels
} else if (src0->type == GGML_TYPE_Q8_0) {
ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data), (const char *) src0->rearranged_weight_gemv, (const char *) wdata, ith, nth); // use Arm Neon/SVE GEMV kernels
}
current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
}
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 16) && (src0->weight_rearranged == true)) {
// use GEMM kernels for tiles of 16, 8, and 4 rows, then GEMV for any remaining rows
if (src0->type == GGML_TYPE_Q4_0) {
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
ggml_gemm_q4_0_q8_0(ne00, 16, ne01, 16, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), ith, nth);
}
int rows_processed = (ne11 / 16) * 16;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
ggml_gemm_q4_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->rearranged_weight_gemm,
(const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), ith, nth);
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
(const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
} else if (src0->type == GGML_TYPE_Q8_0) {
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
ggml_gemm_q8_0_q8_0(ne00, 16, ne01, 16, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), ith, nth);
}
int rows_processed = (ne11 / 16) * 16;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
ggml_gemm_q8_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->rearranged_weight_gemm,
(const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), ith, nth);
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
(const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
}
} else if ((ggml_n_dims(src0) == 2) && (ne11 >= 8) && (src0->weight_rearranged == true)) {
// use GEMM kernels for tiles of 8 and 4 rows, then GEMV for any remaining rows
if (src0->type == GGML_TYPE_Q4_0) {
for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) {
ggml_gemm_q4_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), ith, nth);
}
int rows_processed = (ne11 / 8) * 8;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
(const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
}
for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
} else if (src0->type == GGML_TYPE_Q8_0) {
for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) {
ggml_gemm_q8_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), ith, nth);
}
int rows_processed = (ne11 / 8) * 8;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
(const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
}
for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
}
} else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (src0->weight_rearranged == true)) {
// use the 4-row GEMM kernel, then GEMV for any remaining rows
if (src0->type == GGML_TYPE_Q4_0) {
for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), ith, nth);
}
for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) {
ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
} else if (src0->type == GGML_TYPE_Q8_0) {
for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), ith, nth);
}
for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) {
ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
}
}
#elif defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
if ((ggml_n_dims(src0) == 2) && (src0->weight_rearranged == true)) {
if (src0->type == GGML_TYPE_Q4_0) {
for (int row_iter = 0; row_iter < ne11; row_iter++) {
ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
} else if (src0->type == GGML_TYPE_Q8_0) {
for (int row_iter = 0; row_iter < ne11; row_iter++) {
ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
}
}
}
#endif
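/*
 * Worked example of the tier selection above: with MMLA available, a 2-D
 * rearranged q4_0 or q8_0 weight and ne11 == 23 takes the ne11 >= 16 branch,
 * which issues one 16-row GEMM call (rows 0-15), no 8-row call (7 rows left),
 * one 4-row GEMM call (rows 16-19), and per-row GEMV calls for rows 20-22.
 * The ne11 >= 8 and ne11 >= 4 branches apply the same pattern starting from
 * their largest tile size.
 */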
#if defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
else {
#endif
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) {
break;
}
current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
}
#if defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
}
#endif
}
// ggml_compute_forward_mul_mat_id
@ -21891,4 +22238,26 @@ int ggml_cpu_has_matmul_int8(void) {
#endif
}
#if defined(__ARM_FEATURE_SVE)
static void (*_ggml_gemv_q4_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q4_0_q8_0_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_ggml_gemv_q4_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q4_0_q8_0_blocked8_neon;
#endif
#if defined(__ARM_FEATURE_SVE)
static void (*_ggml_gemv_q8_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q8_0_q8_0_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_ggml_gemv_q8_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q8_0_q8_0_blocked8_neon;
#endif
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
static void ggml_gemv_q4_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
_ggml_gemv_q4_0_q8_0(n, output_channels, input_width, s, vx, vy, ith, nth);
}
static void ggml_gemv_q8_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
_ggml_gemv_q8_0_q8_0(n, output_channels, input_width, s, vx, vy, ith, nth);
}
#endif
////////////////////////////////////////////////////////////////////////////////

View file

@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#define LLAMA_API_INTERNAL
#include "llama.h"
@ -4358,6 +4359,32 @@ struct llama_model_loader {
}
}
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
if ((cur->type == GGML_TYPE_Q4_0) && (cur->ne[1] % 4 == 0)) {
cur->weight_rearranged = true;
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
rearrange_q4_0_weights_for_gemv(cur); // rearrange weights for Arm Neon/SVE GEMV kernels
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
rearrange_q4_0_weights_for_gemm(cur); // rearrange weights for GEMM MMLA kernels
#endif
}
else if ((cur->type == GGML_TYPE_Q8_0) && (cur->ne[1] % 4 == 0)) {
cur->weight_rearranged = true;
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
rearrange_q8_0_weights_for_gemv(cur); // rearrange weights for Arm Neon/SVE GEMV kernels
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
rearrange_q8_0_weights_for_gemm(cur); // rearrange weights for GEMM MMLA kernels
#endif
}
else {
cur->weight_rearranged = false;
}
#else
cur->weight_rearranged = false;
#endif
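// Note: each rearranged layout above is a separate malloc of ggml_nbytes(cur),
// so a q4_0/q8_0 weight can carry up to two extra copies (GEMV and GEMM
// blocked formats) in addition to its original data.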
size_done += n_size;
}