Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization
This commit is contained in: parent 3fd62a6b1c, commit 002e36eaec
6 changed files with 1621 additions and 17 deletions
@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma once

//

@@ -602,6 +603,11 @@ extern "C" {
        void * extra; // extra things e.g. for ggml-cuda.cu

        // char padding[4];
        char padding[9];

        void * rearranged_weight_gemv;
        void * rearranged_weight_gemm;
        bool weight_rearranged;
    };

    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

@@ -2422,6 +2428,15 @@ extern "C" {

    GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);

    GGML_API void rearrange_q4_0_weights_blocked8_neon(struct ggml_tensor * cur);
    GGML_API void rearrange_q4_0_weights_blocked8_sve(struct ggml_tensor * cur);
    GGML_API void rearrange_q4_0_weights_for_gemv(struct ggml_tensor * cur);
    GGML_API void rearrange_q4_0_weights_for_gemm(struct ggml_tensor * cur);
    GGML_API void rearrange_q8_0_weights_blocked8_neon(struct ggml_tensor * cur);
    GGML_API void rearrange_q8_0_weights_blocked8_sve(struct ggml_tensor * cur);
    GGML_API void rearrange_q8_0_weights_for_gemv(struct ggml_tensor * cur);
    GGML_API void rearrange_q8_0_weights_for_gemm(struct ggml_tensor * cur);

#ifdef __cplusplus
}
#endif
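// Illustrative sketch (not part of the patch): the new tensor fields and the
// rearrange_* entry points above suggest a one-time repacking of quantized
// weight tensors before the blocked kernels are used. A possible caller,
// with the helper name ours, might look like this:
//
//     static void maybe_rearrange_weights(struct ggml_tensor * w) {
//         if (w->weight_rearranged) {
//             return; // rearranged_weight_gemv/gemm already populated
//         }
//         if (w->type == GGML_TYPE_Q4_0) {
//             rearrange_q4_0_weights_for_gemv(w);
//             rearrange_q4_0_weights_for_gemm(w);
//         } else if (w->type == GGML_TYPE_Q8_0) {
//             rearrange_q8_0_weights_for_gemv(w);
//             rearrange_q8_0_weights_for_gemm(w);
//         }
//     }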
@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma once

#include "ggml.h"

@@ -609,6 +610,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {

#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)

#ifdef __ARM_FEATURE_SVE
#include <arm_sve.h>
#endif // __ARM_FEATURE_SVE

// precomputed f32 table for f16 (256 KB)
// defined in ggml.c, initialized in ggml_init()
extern float ggml_table_f32_f16[1 << 16];
@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"

@@ -14706,6 +14707,929 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
    assert(k % QK_K == 0);
    block_iq2_s * restrict y = vy;
    quantize_row_iq2_s_reference(x, y, k);
}

// Routines to create the blocked formats
// Note input is array of pointers.
// The exact interleaving format needed is different for GEMM (using SMMLA)
// and GEMV (using SDOT) cases. For GEMM, we interleave 8 pairs of values
// at a time (with the two nibbles separated at runtime to give 2x2x8
// matrices). For GEMV, we need to interleave 4 pairs of values instead.
block_q4_0x4 make_block_q4_0x4(const block_q4_0 * const in[4], unsigned int block_len) {
    block_q4_0x4 out;

    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i]->d;
    }

    for (int i = 0; i < QK4_0 * 2; i++) {
        // We are interleaving 4 rows in blocks of 8, making a total of 32
        // output bytes per block (2 MMLA input vectors). This repeats
        // until we have processed the whole block.
        //
        // Per the comment above, for GEMV cases a similar process is used
        // but with blocks of 4 instead, giving a single DOT input vector.
        //
        // In the case of q4, we add on 128 to convert the top nibble from
        // "bias offset" form to pure sign form (this saves a subtract when
        // we unpack it).
        int src_offset = (i / (4 * block_len)) * block_len;
        int src_id = (i % (4 * block_len)) / block_len;
        src_offset += (i % block_len);

        out.qs[i] = in[src_id]->qs[src_offset] + 0x80;
    }

    return out;
}

// 8-block version - see comments in code above
block_q4_0x8 make_block_q4_0x8(const block_q4_0 * const in[8], unsigned int block_len) {
    block_q4_0x8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i]->d;
    }

    for (int i = 0; i < QK4_0 * 4; i++) {
        int src_offset = (i / (8 * block_len)) * block_len;
        int src_id = (i % (8 * block_len)) / block_len;
        src_offset += (i % block_len);

        out.qs[i] = in[src_id]->qs[src_offset] + 0x80;
    }

    return out;
}
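// Illustrative sketch (not part of the patch): the index arithmetic in the
// make_block_q4_0x* packers above is equivalent to this scalar routine
// (helper name ours). block_len is 8 for the GEMM layout and 4 for GEMV.
//
//     static void interleave_rows_q4_0(uint8_t * dst, const block_q4_0 * const in[], int nrows, int block_len) {
//         for (int i = 0; i < nrows * QK4_0 / 2; i++) {
//             int src_offset = (i / (nrows * block_len)) * block_len + (i % block_len);
//             int src_id     = (i % (nrows * block_len)) / block_len;
//             dst[i] = in[src_id]->qs[src_offset] + 0x80; // nibble bias shift, as above
//         }
//     }
//
// e.g. with nrows = 8, block_len = 8: i = 0..7 copies qs[0..7] of row 0,
// i = 8..15 copies qs[0..7] of row 1, ..., i = 64..71 copies qs[8..15] of row 0.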
block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int block_len) {
    block_q8_0x4 out;

    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i]->d;
    }

    for (int i = 0; i < QK8_0 * 4; i++) {
        int src_offset = (i / (4 * block_len)) * block_len;
        int src_id = (i % (4 * block_len)) / block_len;
        src_offset += (i % block_len);

        out.qs[i] = in[src_id]->qs[src_offset];
    }

    return out;
}

// 8-block version - see comments in code above
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len) {
    block_q8_0x8 out;

    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i]->d;
    }

    for (int i = 0; i < QK8_0 * 8; i++) {
        int src_offset = (i / (8 * block_len)) * block_len;
        int src_id = (i % (8 * block_len)) / block_len;
        src_offset += (i % block_len);

        out.qs[i] = in[src_id]->qs[src_offset];
    }

    return out;
}

void quantize_row_q8_0_and_make_block_q8_0x2(const float * restrict x, void * restrict vy, int k, int rows_interleaved) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0x2 * restrict y = vy;

#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv[rows_interleaved][8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];
        float id[rows_interleaved];

        for (int row_iter = 0; row_iter < rows_interleaved; row_iter++) {
            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);

            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);

            const float amax = vmaxvq_f32(amaxv[0]);

            const float d = amax / ((1 << 7) - 1);
            id[row_iter] = d ? 1.0f / d : 0.0f;

            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
        }

        for (int j = 0; j < 4; j++) {
            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
            int32x4_t vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
        }
    }
#endif
}

void quantize_row_q8_0_and_make_block_q8_0x4(const float * restrict x, void * restrict vy, int k, int rows_interleaved) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0x4 * restrict y = vy;

#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv[rows_interleaved][8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];
        float id[rows_interleaved];

        for (int row_iter = 0; row_iter < rows_interleaved; row_iter++) {
            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);

            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);

            const float amax = vmaxvq_f32(amaxv[0]);

            const float d = amax / ((1 << 7) - 1);
            id[row_iter] = d ? 1.0f / d : 0.0f;

            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
        }

        for (int j = 0; j < 4; j++) {
            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
            int32x4_t vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);

            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
            vi = vcvtnq_s32_f32(v);
            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
        }
    }
#endif
}

inline int64_t roundup(const int64_t a, const int64_t b) {
    int64_t rem = a % b;

    if (rem) {
        return a + b - rem;
    } else {
        return a;
    }
}
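// Illustrative note (not part of the patch): roundup() is what keeps the
// per-thread work ranges below aligned to the blocked layout. For example,
// with output_channels = 4096, nth = 4 and an 8-wide block, thread ith = 1
// gets x0 = roundup(1 * 4096 / 4, 8) = 1024 and xend = roundup(2 * 4096 / 4, 8) = 2048,
// so every thread starts and ends on a whole 8-column group of the
// rearranged weights.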
void ggml_gemv_q4_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_NEON)
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);

    int64_t nb = n / QK4_0;
    int64_t a_nb = n / QK8_0;

    const uint8x16_t m4b = vdupq_n_u8(0x0F);
    const int8x16_t s8b = vdupq_n_s8(0x8);

    const block_q4_0x8 * b_ptr_start = vx;
    const block_q8_0 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width; y++) {
        for (int64_t x = x0 / 8; x < xend / 8; x++) {
            // Pointers to LHS blocks
            const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
            // Pointers to RHS blocks
            const block_q4_0x8 * b_ptr = b_ptr_start + (x * nb);
            // Master FP accumulator
            float32x4_t acc_row[2];
            acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f);

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const uint8x16_t rhs_raw_vec_0_0 = vld1q_u8(b_ptr[b].qs);
                const uint8x16_t rhs_raw_vec_1_0 = vld1q_u8(b_ptr[b].qs + 16);
                const uint8x16_t rhs_raw_vec_0_1 = vld1q_u8(b_ptr[b].qs + 32);
                const uint8x16_t rhs_raw_vec_1_1 = vld1q_u8(b_ptr[b].qs + 48);
                const uint8x16_t rhs_raw_vec_0_2 = vld1q_u8(b_ptr[b].qs + 64);
                const uint8x16_t rhs_raw_vec_1_2 = vld1q_u8(b_ptr[b].qs + 80);
                const uint8x16_t rhs_raw_vec_0_3 = vld1q_u8(b_ptr[b].qs + 96);
                const uint8x16_t rhs_raw_vec_1_3 = vld1q_u8(b_ptr[b].qs + 112);

                const int8x16_t rhs_vec_0_0_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_0, m4b)), s8b);
                const int8x16_t rhs_vec_0_1_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_1, m4b)), s8b);
                const int8x16_t rhs_vec_0_2_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_2, m4b)), s8b);
                const int8x16_t rhs_vec_0_3_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_3, m4b)), s8b);
                const int8x16_t rhs_vec_1_0_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_0, m4b)), s8b);
                const int8x16_t rhs_vec_1_1_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_1, m4b)), s8b);
                const int8x16_t rhs_vec_1_2_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_2, m4b)), s8b);
                const int8x16_t rhs_vec_1_3_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_3, m4b)), s8b);

                const int8x16_t rhs_vec_0_0_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_0), 4);
                const int8x16_t rhs_vec_0_1_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_1), 4);
                const int8x16_t rhs_vec_0_2_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_2), 4);
                const int8x16_t rhs_vec_0_3_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_3), 4);
                const int8x16_t rhs_vec_1_0_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_0), 4);
                const int8x16_t rhs_vec_1_1_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_1), 4);
                const int8x16_t rhs_vec_1_2_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_2), 4);
                const int8x16_t rhs_vec_1_3_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_3), 4);

                // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
                const float16x8_t col_scale_f16 = vld1q_f16(b_ptr[b].d);
                const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16));
                const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16));

                const float16x4_t row_scale_f16 = vld1_dup_f16(&(a_ptr[b].d));
                const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);

                const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs);
                const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16);

                int32x4_t iacc0 = vdupq_n_s32(0);
                int32x4_t iacc1 = vdupq_n_s32(0);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3);

                acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32));
                acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32));
            }

            vst1q_f32(s + (y * output_channels + x * 8), acc_row[0]);
            vst1q_f32(s + (y * output_channels + x * 8 + 4), acc_row[1]);
        }
    }
#endif
}
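// Illustrative note (not part of the patch): in scalar terms, one (y, x, b)
// step of the blocked-8 GEMV above computes, for each of the 8 interleaved
// output columns c,
//     s[y * output_channels + x * 8 + c] +=
//         fp32(b_ptr[b].d[c]) * fp32(a_ptr[b].d) * sum_{k=0..31} w[c][k] * a_ptr[b].qs[k]
// where w[c][k] are the signed 4-bit weights reconstructed from the
// interleaved nibbles (low nibbles via the m4b mask and s8b subtract, high
// nibbles via the arithmetic shift right by 4).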
void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);

    int64_t nb = n / QK4_0;
    int64_t a_nb = n / QK8_0;

    const svuint8_t m4b = svdup_u8(0x0F);
    const svint8_t s8b = svdup_s8(0x8);

    const svbool_t ptrue = svptrue_b8();

    const block_q4_0x8 * b_ptr_start = vx;
    const block_q8_0 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width; y++) {
        for (int64_t x = x0 / 8; x < xend / 8; x++) {
            // Pointers to LHS blocks
            const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
            // Pointers to RHS blocks
            const block_q4_0x8 * b_ptr = b_ptr_start + (x * nb);

            // Master FP accumulator
            svfloat32_t acc_row = svdup_f32(0.0f);

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const svuint8_t rhs_raw_vec_0_0 = svld1_u8(ptrue, b_ptr[b].qs);
                const svuint8_t rhs_raw_vec_0_1 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 1);
                const svuint8_t rhs_raw_vec_0_2 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 2);
                const svuint8_t rhs_raw_vec_0_3 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 3);

                const svint8_t rhs_vec_0_0_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_0), 4);
                const svint8_t rhs_vec_0_1_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_1), 4);
                const svint8_t rhs_vec_0_2_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_2), 4);
                const svint8_t rhs_vec_0_3_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_3), 4);

                const svint8_t rhs_vec_0_0_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_0, m4b)), s8b);
                const svint8_t rhs_vec_0_1_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_1, m4b)), s8b);
                const svint8_t rhs_vec_0_2_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_2, m4b)), s8b);
                const svint8_t rhs_vec_0_3_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_3, m4b)), s8b);

                // Scale values
                const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d));
                const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16);

                const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d);
                const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16);

                const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs);
                const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16);

                svint32_t iacc = svdup_s32(0);

                iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0);
                iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0);

                iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1);
                iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1);

                iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2);
                iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2);

                iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3);
                iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3);

                acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32));
            }

            svst1(ptrue, s + (y * output_channels + x * 8), acc_row);
        }
    }
#endif
}
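// Illustrative note (not part of the patch): the SVE GEMV paths read the whole
// 128-byte interleaved qs[] array with four vector loads and write eight f32
// results with a single svst1, so they appear to assume an SVE vector length
// of 256 bits; other vector lengths would need a different blocking.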
void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);

    int64_t nb = n / QK4_0;
    int64_t a_nb = n / QK8_0;

    const uint8x16_t m4b = vdupq_n_u8(0x0F);
    const int8x16_t s8b = vdupq_n_s8(0x8);

    const block_q4_0x4 * b_ptr_start = vx;
    const block_q8_0x4 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width / 4; y += rows / 4) {
        for (int64_t x = x0 / 4; x < xend / 4; x++) {
            const block_q8_0x4 * a_ptrs[rows / 4];

            a_ptrs[0] = a_ptr_start + (y * a_nb);
            for (int i = 0; i < (rows / 4) - 1; i++) {
                a_ptrs[i + 1] = a_ptrs[i] + a_nb;
            }

            const block_q4_0x4 * b_ptr = b_ptr_start + (x * nb);

            // Master FP accumulators
            float32x4_t acc_rows[rows];
            for (int i = 0; i < rows; i++) {
                acc_rows[i] = vdupq_n_f32(0.0f);
            }

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const uint8x16_t rhs_raw_mat_01_0 = vld1q_u8(b_ptr[b].qs);
                const uint8x16_t rhs_raw_mat_23_0 = vld1q_u8(b_ptr[b].qs + 16);
                const uint8x16_t rhs_raw_mat_01_1 = vld1q_u8(b_ptr[b].qs + 32);
                const uint8x16_t rhs_raw_mat_23_1 = vld1q_u8(b_ptr[b].qs + 48);

                // 4-bit -> 8-bit
                const int8x16_t rhs_mat_01_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_0, m4b)), s8b);
                const int8x16_t rhs_mat_23_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_0, m4b)), s8b);
                const int8x16_t rhs_mat_01_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_1, m4b)), s8b);
                const int8x16_t rhs_mat_23_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_1, m4b)), s8b);
                const int8x16_t rhs_mat_01_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_0), 4);
                const int8x16_t rhs_mat_23_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_0), 4);
                const int8x16_t rhs_mat_01_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_1), 4);
                const int8x16_t rhs_mat_23_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_1), 4);

                // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
                const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
                const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);

                // Process LHS in pairs of rows
                for (int rp = 0; rp < rows / 4; rp++) {
                    const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
                    const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16);
                    const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32);
                    const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48);

                    const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64);
                    const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80);
                    const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96);
                    const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112);

                    // Do the MMLAs into 2x2 matrices
                    const int32x4_t iacc_mat_00 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
                    const int32x4_t iacc_mat_01 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
                    const int32x4_t iacc_mat_10 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3);
                    const int32x4_t iacc_mat_11 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3);

                    // Straighten out to make 4 row vectors
                    const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
                    const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
                    const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
                    const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));

                    const float16x4_t row_scale_f16 = vld1_f16(a_ptrs[rp][b].d);
                    const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);

                    acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0));
                    acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1));
                    acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2));
                    acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3));
                }
            }

            for (int i = 0; i < rows; i++) {
                vst1q_f32(s + ((y * 4 + i) * output_channels + x * 4), acc_rows[i]);
            }
        }
    }
#endif
}
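// Illustrative note (not part of the patch): each vmmlaq_s32 (SMMLA) call above
// multiplies a 2x8 slice of two LHS rows by two RHS columns and accumulates a
// 2x2 int32 tile, so after the four chained calls per 32-element block the
// lane layout is (rN = LHS row N of the current pair, cN = RHS column N):
//     iacc_mat_00 = [ r0.c0, r0.c1, r1.c0, r1.c1 ]
//     iacc_mat_01 = [ r0.c2, r0.c3, r1.c2, r1.c3 ]
//     iacc_mat_10 = [ r2.c0, r2.c1, r3.c0, r3.c1 ]
//     iacc_mat_11 = [ r2.c2, r2.c3, r3.c2, r3.c3 ]
// The vtrn1q_u64/vtrn2q_u64 pairs then regroup these into iacc_row_0..3, one
// vector of four column results per LHS row, before the fp32 scaling.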
void ggml_gemm_q4_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    int rows = 2;
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);

    int64_t nb = n / QK4_0;
    int64_t a_nb = n / QK8_0;

    const uint8x16_t m4b = vdupq_n_u8(0x0F);
    const int8x16_t s8b = vdupq_n_s8(0x8);

    const block_q4_0x4 * b_ptr_start = vx;
    const block_q8_0x2 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width / 2; y += rows / 2) {
        for (int64_t x = x0 / 4; x < xend / 4; x++) {
            const block_q8_0x2 * a_ptrs[rows / 2];

            a_ptrs[0] = a_ptr_start + (y * a_nb);

            const block_q4_0x4 * b_ptr = b_ptr_start + (x * nb);

            // Master FP accumulators
            float32x4_t acc_rows[rows];
            acc_rows[0] = vdupq_n_f32(0.0f);
            acc_rows[1] = vdupq_n_f32(0.0f);

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const uint8x16_t rhs_raw_mat_01_0 = vld1q_u8(b_ptr[b].qs);
                const uint8x16_t rhs_raw_mat_23_0 = vld1q_u8(b_ptr[b].qs + 16);
                const uint8x16_t rhs_raw_mat_01_1 = vld1q_u8(b_ptr[b].qs + 32);
                const uint8x16_t rhs_raw_mat_23_1 = vld1q_u8(b_ptr[b].qs + 48);

                const int8x16_t rhs_mat_01_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_0, m4b)), s8b);
                const int8x16_t rhs_mat_23_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_0, m4b)), s8b);
                const int8x16_t rhs_mat_01_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_1, m4b)), s8b);
                const int8x16_t rhs_mat_23_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_1, m4b)), s8b);

                const int8x16_t rhs_mat_01_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_0), 4);
                const int8x16_t rhs_mat_23_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_0), 4);
                const int8x16_t rhs_mat_01_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_1), 4);
                const int8x16_t rhs_mat_23_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_1), 4);

                // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
                const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
                const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);

                // Process LHS in pairs of rows
                int rp = 0;
                const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
                const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 16);

                const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 32);
                const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 48);

                // Do the MMLAs into 2x2 matrices
                const int32x4_t iacc_mat_00 =
                    vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
                const int32x4_t iacc_mat_01 =
                    vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);

                // Straighten out to make 2 row vectors
                const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
                const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));

                const float16x4_t row_scale_f16_0 = vld1_dup_f16(&(a_ptrs[rp][b].d[0]));
                const float32x4_t row_scale_f32_0 = vcvt_f32_f16(row_scale_f16_0);
                const float16x4_t row_scale_f16_1 = vld1_dup_f16(&(a_ptrs[rp][b].d[1]));
                const float32x4_t row_scale_f32_1 = vcvt_f32_f16(row_scale_f16_1);

                acc_rows[rp * 2] = vfmaq_f32(acc_rows[rp * 2], vcvtq_f32_s32(iacc_row_0), vmulq_f32(col_scale_f32, row_scale_f32_0));
                acc_rows[rp * 2 + 1] = vfmaq_f32(acc_rows[rp * 2 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_f32(col_scale_f32, row_scale_f32_1));
            }

            vst1q_f32(s + ((y * 2) * output_channels + x * 4), acc_rows[0]);
            vst1q_f32(s + ((y * 2 + 1) * output_channels + x * 4), acc_rows[1]);
        }
    }
#endif
}
void ggml_gemv_q8_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_NEON)
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);

    int64_t nb = n / QK8_0;
    int64_t a_nb = n / QK8_0;

    const block_q8_0x8 * b_ptr_start = vx;
    const block_q8_0 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width; y++) {
        for (int64_t x = x0 / 8; x < xend / 8; x++) {
            // Pointers to LHS blocks
            const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
            // Pointers to RHS blocks
            const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);
            // Master FP accumulator
            float32x4_t acc_row[2];
            acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f);

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const int8x16_t rhs_vec_0_0_0 = vld1q_s8(b_ptr[b].qs);
                const int8x16_t rhs_vec_1_0_0 = vld1q_s8(b_ptr[b].qs + 16);
                const int8x16_t rhs_vec_0_1_0 = vld1q_s8(b_ptr[b].qs + 32);
                const int8x16_t rhs_vec_1_1_0 = vld1q_s8(b_ptr[b].qs + 48);
                const int8x16_t rhs_vec_0_2_0 = vld1q_s8(b_ptr[b].qs + 64);
                const int8x16_t rhs_vec_1_2_0 = vld1q_s8(b_ptr[b].qs + 80);
                const int8x16_t rhs_vec_0_3_0 = vld1q_s8(b_ptr[b].qs + 96);
                const int8x16_t rhs_vec_1_3_0 = vld1q_s8(b_ptr[b].qs + 112);
                const int8x16_t rhs_vec_0_0_1 = vld1q_s8(b_ptr[b].qs + 128);
                const int8x16_t rhs_vec_1_0_1 = vld1q_s8(b_ptr[b].qs + 144);
                const int8x16_t rhs_vec_0_1_1 = vld1q_s8(b_ptr[b].qs + 160);
                const int8x16_t rhs_vec_1_1_1 = vld1q_s8(b_ptr[b].qs + 176);
                const int8x16_t rhs_vec_0_2_1 = vld1q_s8(b_ptr[b].qs + 192);
                const int8x16_t rhs_vec_1_2_1 = vld1q_s8(b_ptr[b].qs + 208);
                const int8x16_t rhs_vec_0_3_1 = vld1q_s8(b_ptr[b].qs + 224);
                const int8x16_t rhs_vec_1_3_1 = vld1q_s8(b_ptr[b].qs + 240);

                // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
                const float16x8_t col_scale_f16 = vld1q_f16(b_ptr[b].d);
                const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16));
                const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16));

                const float16x4_t row_scale_f16 = vld1_dup_f16(&(a_ptr[b].d));
                const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);

                const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs);
                const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16);

                int32x4_t iacc0 = vdupq_n_s32(0);
                int32x4_t iacc1 = vdupq_n_s32(0);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2);

                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3);
                iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3);

                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3);
                iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3);

                acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32));
                acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32));
            }

            vst1q_f32(s + (y * output_channels + x * 8), acc_row[0]);
            vst1q_f32(s + (y * output_channels + x * 8 + 4), acc_row[1]);
        }
    }
#endif
}
void ggml_gemv_q8_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE)
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);

    int64_t nb = n / QK8_0;
    int64_t a_nb = n / QK8_0;

    const svbool_t ptrue = svptrue_b8();

    const block_q8_0x8 * b_ptr_start = vx;
    const block_q8_0 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width; y++) {
        for (int64_t x = x0 / 8; x < xend / 8; x++) {
            // Pointers to LHS blocks
            const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb);
            // Pointers to RHS blocks
            const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb);

            // Master FP accumulator
            svfloat32_t acc_row = svdup_f32(0.0f);

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const svint8_t rhs_vec_0_0_0 = svld1_s8(ptrue, b_ptr[b].qs);
                const svint8_t rhs_vec_0_1_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 1);
                const svint8_t rhs_vec_0_2_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 2);
                const svint8_t rhs_vec_0_3_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 3);
                const svint8_t rhs_vec_0_0_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 4);
                const svint8_t rhs_vec_0_1_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 5);
                const svint8_t rhs_vec_0_2_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 6);
                const svint8_t rhs_vec_0_3_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 7);

                // Scale values
                const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d));
                const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16);

                const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d);
                const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16);

                const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs);
                const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16);

                svint32_t iacc = svdup_s32(0);

                iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0);
                iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0);

                iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1);
                iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1);

                iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2);
                iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2);

                iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3);
                iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3);

                acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32));
            }

            svst1(ptrue, s + (y * output_channels + x * 8), acc_row);
        }
    }
#endif
}
void ggml_gemm_q8_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);

    int64_t nb = n / QK8_0;
    int64_t a_nb = n / QK8_0;

    const block_q8_0x4 * b_ptr_start = vx;
    const block_q8_0x4 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width / 4; y += rows / 4) {
        for (int64_t x = x0 / 4; x < xend / 4; x++) {
            const block_q8_0x4 * a_ptrs[rows / 4];

            a_ptrs[0] = a_ptr_start + (y * a_nb);
            for (int i = 0; i < (rows / 4) - 1; i++) {
                a_ptrs[i + 1] = a_ptrs[i] + a_nb;
            }

            const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb);

            // Master FP accumulators
            float32x4_t acc_rows[rows];
            for (int i = 0; i < rows; i++) {
                acc_rows[i] = vdupq_n_f32(0.0f);
            }

            for (int64_t b = 0; b < nb; b++) {
                // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers)
                const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs);
                const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16);
                const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32);
                const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48);
                const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64);
                const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80);
                const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96);
                const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112);

                // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
                const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
                const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);

                // Process LHS in pairs of rows
                for (int rp = 0; rp < rows / 4; rp++) {
                    const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
                    const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16);
                    const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32);
                    const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48);

                    const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64);
                    const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80);
                    const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96);
                    const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112);

                    // Do the MMLAs into 2x2 matrices
                    const int32x4_t iacc_mat_00 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
                    const int32x4_t iacc_mat_01 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);
                    const int32x4_t iacc_mat_10 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3);
                    const int32x4_t iacc_mat_11 =
                        vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3);

                    // Straighten out to make 4 row vectors
                    const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
                    const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
                    const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));
                    const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11)));

                    const float16x4_t row_scale_f16 = vld1_f16(a_ptrs[rp][b].d);
                    const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16);

                    acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0));
                    acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1));
                    acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2));
                    acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3));
                }
            }

            for (int i = 0; i < rows; i++) {
                vst1q_f32(s + ((y * 4 + i) * output_channels + x * 4), acc_rows[i]);
            }
        }
    }
#endif
}
void ggml_gemm_q8_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    int rows = 2;
    int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4);
    int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4);

    int64_t nb = n / QK8_0;
    int64_t a_nb = n / QK8_0;

    const block_q8_0x4 * b_ptr_start = vx;
    const block_q8_0x2 * a_ptr_start = vy;

    for (int64_t y = 0; y < input_width / 2; y += rows / 2) {
        for (int64_t x = x0 / 4; x < xend / 4; x++) {
            const block_q8_0x2 * a_ptrs[rows / 2];

            a_ptrs[0] = a_ptr_start + (y * a_nb);

            const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb);

            // Master FP accumulators
            float32x4_t acc_rows[rows];
            acc_rows[0] = vdupq_n_f32(0.0f);
            acc_rows[1] = vdupq_n_f32(0.0f);

            for (int64_t b = 0; b < nb; b++) {
                const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs);
                const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16);
                const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32);
                const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48);
                const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64);
                const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80);
                const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96);
                const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112);

                // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32
                const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d);
                const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16);

                // Process LHS in pairs of rows
                int rp = 0;
                const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs);
                const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 16);

                const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 32);
                const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 48);

                // Do the MMLAs into 2x2 matrices
                const int32x4_t iacc_mat_00 =
                    vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3);
                const int32x4_t iacc_mat_01 =
                    vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3);

                // Straighten out to make 2 row vectors
                const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));
                const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01)));

                const float16x4_t row_scale_f16_0 = vld1_dup_f16(&(a_ptrs[rp][b].d[0]));
                const float32x4_t row_scale_f32_0 = vcvt_f32_f16(row_scale_f16_0);
                const float16x4_t row_scale_f16_1 = vld1_dup_f16(&(a_ptrs[rp][b].d[1]));
                const float32x4_t row_scale_f32_1 = vcvt_f32_f16(row_scale_f16_1);

                acc_rows[rp * 2] = vfmaq_f32(acc_rows[rp * 2], vcvtq_f32_s32(iacc_row_0), vmulq_f32(col_scale_f32, row_scale_f32_0));
                acc_rows[rp * 2 + 1] = vfmaq_f32(acc_rows[rp * 2 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_f32(col_scale_f32, row_scale_f32_1));
            }
            vst1q_f32(s + ((y * 2) * output_channels + x * 4), acc_rows[0]);
            vst1q_f32(s + ((y * 2 + 1) * output_channels + x * 4), acc_rows[1]);
        }
    }
#endif
}

static bool validate_float(float f, size_t i) {
@@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#pragma once

#define GGML_COMMON_DECL_C

@@ -7,6 +8,250 @@

// GGML internal header

#include <stdint.h>
#include <stddef.h>

#define QK4_0 32
typedef struct {
    ggml_fp16_t d;          // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");

#define QK4_1 32
typedef struct {
    ggml_fp16_t d;          // delta
    ggml_fp16_t m;          // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");

#define QK5_0 32
typedef struct {
    ggml_fp16_t d;          // delta
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_0 / 2];  // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");

#define QK5_1 32
typedef struct {
    ggml_fp16_t d;          // delta
    ggml_fp16_t m;          // min
    uint8_t qh[4];          // 5-th bit of quants
    uint8_t qs[QK5_1 / 2];  // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");

#define QK8_0 32
typedef struct {
    ggml_fp16_t d;      // delta
    int8_t qs[QK8_0];   // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

#define QK8_1 32
typedef struct {
    float d;            // delta
    float s;            // d * sum(qs[i])
    int8_t qs[QK8_1];   // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");

typedef struct {
    ggml_fp16_t d[4];       // deltas for 4 q4_0 blocks
    uint8_t qs[QK4_0 * 2];  // nibbles / quants for 4 q4_0 blocks
} block_q4_0x4;
static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_fp16_t) + QK4_0 * 2, "wrong q4_0x4 block size/padding");

typedef struct {
    ggml_fp16_t d[8];       // deltas for 8 q4_0 blocks
    uint8_t qs[QK4_0 * 4];  // nibbles / quants for 8 q4_0 blocks
} block_q4_0x8;
static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_fp16_t) + QK4_0 * 4, "wrong q4_0x8 block size/padding");

typedef struct {
    ggml_fp16_t d[16];      // deltas for 16 q4_0 blocks
    uint8_t qs[QK4_0 * 8];  // nibbles / quants for 16 q4_0 blocks
} block_q4_0x16;
static_assert(sizeof(block_q4_0x16) == 16 * sizeof(ggml_fp16_t) + QK4_0 * 8, "wrong q4_0x16 block size/padding");

typedef struct {
    ggml_fp16_t d[64];      // deltas for 64 q4_0 blocks
    uint8_t qs[QK4_0 * 32]; // nibbles / quants for 64 q4_0 blocks
} block_q4_0x64;
static_assert(sizeof(block_q4_0x64) == 64 * sizeof(ggml_fp16_t) + QK4_0 * 32, "wrong q4_0x64 block size/padding");

typedef struct {
    ggml_fp16_t d[2];       // deltas for 2 q8_0 blocks
    int8_t qs[QK8_0 * 2];   // quants for 2 q8_0 blocks
} block_q8_0x2;
static_assert(sizeof(block_q8_0x2) == 2 * sizeof(ggml_fp16_t) + QK8_0 * 2, "wrong q8_0x2 block size/padding");

typedef struct {
    ggml_fp16_t d[4];       // deltas for 4 q8_0 blocks
    int8_t qs[QK8_0 * 4];   // quants for 4 q8_0 blocks
} block_q8_0x4;
static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_fp16_t) + QK8_0 * 4, "wrong q8_0x4 block size/padding");

typedef struct {
    ggml_fp16_t d[8];       // deltas for 8 q8_0 blocks
    int8_t qs[QK8_0 * 8];   // quants for 8 q8_0 blocks
} block_q8_0x8;
static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_fp16_t) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
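// Illustrative note (not part of the patch): as used by the blocked GEMV
// kernels, a q4_0 weight matrix of shape [output_channels][n] repacks into
// (output_channels / 8) * (n / QK4_0) block_q4_0x8 records. Each record is
// 8 * 2 + 128 = 144 bytes, the same storage as the 8 q4_0 blocks it replaces
// (8 * 18 bytes); only the byte order changes so that 8 rows of the same
// 32-element column block are contiguous.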
//
|
||||
// Super-block quantization structures
|
||||
//
|
||||
|
||||
// Super-block size
|
||||
#ifdef GGML_QKK_64
|
||||
#define QK_K 64
|
||||
#define K_SCALE_SIZE 4
|
||||
#else
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
#endif
|
||||
|
||||
// 2-bit quantization
|
||||
// weight is represented as x = a * q + b
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 2.625 bits per weight
|
||||
typedef struct {
|
||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||
uint8_t qs[QK_K/4]; // quants
|
||||
ggml_fp16_t d; // super-block scale for quantized scales
|
||||
ggml_fp16_t dmin; // super-block scale for quantized mins
|
||||
} block_q2_K;
|
||||
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
|
||||
|
||||
// 3-bit quantization
|
||||
// weight is represented as x = a * q
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 3.4375 bits per weight
|
||||
#ifdef GGML_QKK_64
|
||||
typedef struct {
|
||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||
uint8_t scales[2];
|
||||
ggml_fp16_t d; // super-block scale
|
||||
} block_q3_K;
|
||||
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
|
||||
#else
|
||||
typedef struct {
|
||||
uint8_t hmask[QK_K/8]; // quants - high bit
|
||||
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
||||
uint8_t scales[12]; // scales, quantized with 6 bits
|
||||
ggml_fp16_t d; // super-block scale
|
||||
} block_q3_K;
|
||||
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
|
||||
#endif
|
||||
|
||||
// 4-bit quantization
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 4.5 bits per weight
|
||||
#ifdef GGML_QKK_64
|
||||
typedef struct {
|
||||
ggml_fp16_t d[2]; // super-block scales/mins
|
||||
uint8_t scales[2]; // 4-bit block scales/mins
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
|
||||
#else
|
||||
typedef struct {
|
||||
ggml_fp16_t d; // super-block scale for quantized scales
|
||||
ggml_fp16_t dmin; // super-block scale for quantized mins
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||
} block_q4_K;
|
||||
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
|
||||
#endif
|
||||
|
||||
// 5-bit quantization
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 5.5 bits per weight
|
||||
#ifdef GGML_QKK_64
|
||||
typedef struct {
|
||||
ggml_fp16_t d; // super-block scale
|
||||
int8_t scales[QK_K/16]; // 8-bit block scales
|
||||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_K;
|
||||
static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
|
||||
#else
|
||||
typedef struct {
|
||||
ggml_fp16_t d; // super-block scale for quantized scales
|
||||
ggml_fp16_t dmin; // super-block scale for quantized mins
|
||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qh[QK_K/8]; // quants, high bit
|
||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||
} block_q5_K;
|
||||
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
|
||||
#endif
|
||||
|
||||
// 6-bit quantization
|
||||
// weight is represented as x = a * q
|
||||
// 16 blocks of 16 elements each
|
||||
// Effectively 6.5625 bits per weight
|
||||
typedef struct {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||
ggml_fp16_t d; // super-block scale
|
||||
} block_q6_K;
|
||||
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
|
||||
|
||||
// This is only used for intermediate quantization and dot products
|
||||
typedef struct {
|
||||
float d; // delta
|
||||
int8_t qs[QK_K]; // quants
|
||||
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
|
||||
} block_q8_K;
|
||||
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");

// (Almost) "true" 2-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 2.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
    ggml_fp16_t d;
    uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");

// 2.3125 bpw quants
typedef struct {
    ggml_fp16_t d;
    uint16_t qs[QK_K/8];
    uint8_t scales[QK_K/32];
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");

// (Almost) "true" 3-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 3.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
    ggml_fp16_t d;
    uint8_t qs[3*QK_K/8];
} block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");

typedef struct {
    ggml_fp16_t d;
    uint8_t qs[QK_K/8];
    uint8_t scales[QK_K/16];
} block_iq1_s;
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");

// Non-linear quants
#define QK4_NL 32
typedef struct {
    ggml_fp16_t d;
    uint8_t qs[QK4_NL/2];
} block_iq4_nl;
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");

#ifdef __cplusplus
extern "C" {
#endif

@ -127,6 +372,25 @@ void iq2xs_free_impl(enum ggml_type type);
void iq3xs_init_impl(int grid_size);
void iq3xs_free_impl(int grid_size);

block_q4_0x4 make_block_q4_0x4(const block_q4_0 * const in[4], unsigned int block_len);
block_q4_0x8 make_block_q4_0x8(const block_q4_0 * const in[8], unsigned int block_len);
block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int block_len);
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len);
void quantize_row_q8_0_and_make_block_q8_0x2(const float * restrict x, void * restrict vy, int k, int rows_interleaved);
void quantize_row_q8_0_and_make_block_q8_0x4(const float * restrict x, void * restrict vy, int k, int rows_interleaved);

// GEMV
void ggml_gemv_q4_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemv_q8_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemv_q8_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);

// GEMM
void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemm_q4_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemm_q8_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
void ggml_gemm_q8_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);
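
// Parameter mapping note (informational, taken from how these are invoked later
// in ggml_compute_forward_mul_mat, not an additional API): for the GEMV entry
// points, n is the row length ne00, output_channels is the number of weight
// rows ne01, input_width is 1, s is the float output, vx is the rearranged
// (blocked) weight buffer, vy is the quantized activation row, and ith/nth are
// the thread index and thread count. A call therefore looks roughly like
//   ggml_gemv_q4_0_q8_0(ne00, ne01, 1, dst_row, rearranged_weights, q8_activations, ith, nth);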

#ifdef __cplusplus
}
#endif

403
ggml/src/ggml.c
@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC

@ -473,6 +474,204 @@ int64_t ggml_cycles_per_ms(void) {
    return CLOCKS_PER_SEC/1000;
}

#ifdef GGML_PERF
#define ggml_perf_time_ms()       ggml_time_ms()
#define ggml_perf_time_us()       ggml_time_us()
#define ggml_perf_cycles()        ggml_cycles()
#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms()
#else
#define ggml_perf_time_ms()       0
#define ggml_perf_time_us()       0
#define ggml_perf_cycles()        0
#define ggml_perf_cycles_per_ms() 0
#endif

void rearrange_q4_0_weights_blocked8_neon(struct ggml_tensor * cur) {
    block_q4_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
    block_q4_0x8 * out_ptr_B_start = out_ptr_B;
    int64_t nb = cur->ne[0] / QK4_0;

    for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
        const block_q4_0 * in_ptrs[8];

        in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 8 * nb);
        for (int i = 0; i < 7; i++) {
            in_ptrs[i + 1] = in_ptrs[i] + nb;
        }

        for (int64_t x = 0; x < nb; x++) {
            *out_ptr_B = make_block_q4_0x8(in_ptrs, 4); // block_len=4 for SDOT
            out_ptr_B++;

            for (int i = 0; i < 8; i++) {
                in_ptrs[i]++;
            }
        }
    }
    cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
}

void rearrange_q4_0_weights_blocked8_sve(struct ggml_tensor * cur) {
#if defined(__ARM_FEATURE_SVE)
    if (svcntw() != 8) {
        printf("ggml_gemv_q4_0_q8_0_blocked8_sve: SVE VL != 256 - aborting. Use Arm Neon GEMV kernels\n");
        exit(1);
    }

    block_q4_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
    block_q4_0x8 * out_ptr_B_start = out_ptr_B;
    int64_t nb = cur->ne[0] / QK4_0;

    for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
        const block_q4_0 * in_ptrs[8];

        in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 8 * nb);
        for (int i = 0; i < 7; i++) {
            in_ptrs[i + 1] = in_ptrs[i] + nb;
        }

        for (int64_t x = 0; x < nb; x++) {
            *out_ptr_B = make_block_q4_0x8(in_ptrs, 4); // block_len=4 for SDOT
            out_ptr_B++;

            for (int i = 0; i < 8; i++) {
                in_ptrs[i]++;
            }
        }
    }
    cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
#endif
}
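
// Note (informational, not part of the patch): svcntw() returns the number of
// 32-bit lanes in one SVE vector, so the svcntw() == 8 requirement above is
// equivalent to a 256-bit SVE vector length (8 lanes x 32 bits). On hardware
// with a different vector length the SVE rearrangement aborts and the NEON
// kernels should be used instead.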

#if defined(__ARM_FEATURE_SVE)
static void (*_rearrange_q4_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q4_0_weights_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_rearrange_q4_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q4_0_weights_blocked8_neon;
#endif

#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
void rearrange_q4_0_weights_for_gemv(struct ggml_tensor * cur) { _rearrange_q4_0_weights_for_gemv(cur); }
#endif

void rearrange_q4_0_weights_for_gemm(struct ggml_tensor * cur) {
    block_q4_0x4 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
    block_q4_0x4 * out_ptr_B_start = out_ptr_B;
    int64_t nb = cur->ne[0] / QK4_0;

    for (int y_out = 0; y_out < cur->ne[1] / 4; y_out++) {
        const block_q4_0 * in_ptrs[4];

        in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 4 * nb);
        for (int i = 0; i < 3; i++) {
            in_ptrs[i + 1] = in_ptrs[i] + nb;
        }

        for (int64_t x = 0; x < nb; x++) {
            *out_ptr_B = make_block_q4_0x4(in_ptrs, 8); // block_len=8 for SMMLA
            out_ptr_B++;

            for (int i = 0; i < 4; i++) {
                in_ptrs[i]++;
            }
        }
    }
    cur->rearranged_weight_gemm = (uint8_t *) out_ptr_B_start;
}

void rearrange_q8_0_weights_blocked8_neon(struct ggml_tensor * cur) {
    block_q8_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
    block_q8_0x8 * out_ptr_B_start = out_ptr_B;
    int64_t nb = cur->ne[0] / QK8_0;

    for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
        const block_q8_0 * in_ptrs[8];

        in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 8 * nb);
        for (int i = 0; i < 7; i++) {
            in_ptrs[i + 1] = in_ptrs[i] + nb;
        }

        for (int64_t x = 0; x < nb; x++) {
            *out_ptr_B = make_block_q8_0x8(in_ptrs, 4); // block_len=4 for SDOT
            out_ptr_B++;

            for (int i = 0; i < 8; i++) {
                in_ptrs[i]++;
            }
        }
    }
    cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
}

void rearrange_q8_0_weights_blocked8_sve(struct ggml_tensor * cur) {
#if defined(__ARM_FEATURE_SVE)
    if (svcntw() != 8) {
        printf("ggml_gemv_q8_0_q8_0_blocked8_sve: SVE VL != 256 - aborting. Use Arm Neon GEMV kernels\n");
        exit(1);
    }

    block_q8_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
    block_q8_0x8 * out_ptr_B_start = out_ptr_B;
    int64_t nb = cur->ne[0] / QK8_0;

    for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) {
        const block_q8_0 * in_ptrs[8];

        in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 8 * nb);
        for (int i = 0; i < 7; i++) {
            in_ptrs[i + 1] = in_ptrs[i] + nb;
        }

        for (int64_t x = 0; x < nb; x++) {
            *out_ptr_B = make_block_q8_0x8(in_ptrs, 4); // block_len=4 for SDOT
            out_ptr_B++;

            for (int i = 0; i < 8; i++) {
                in_ptrs[i]++;
            }
        }
    }
    cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start;
#endif
}

#if defined(__ARM_FEATURE_SVE)
static void (*_rearrange_q8_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q8_0_weights_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_rearrange_q8_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q8_0_weights_blocked8_neon;
#endif

#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
void rearrange_q8_0_weights_for_gemv(struct ggml_tensor * cur) { _rearrange_q8_0_weights_for_gemv(cur); }
#endif

void rearrange_q8_0_weights_for_gemm(struct ggml_tensor * cur) {
    block_q8_0x4 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data;
    block_q8_0x4 * out_ptr_B_start = out_ptr_B;
    int64_t nb = cur->ne[0] / QK8_0;

    for (int y_out = 0; y_out < cur->ne[1] / 4; y_out++) {
        const block_q8_0 * in_ptrs[4];

        in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 4 * nb);
        for (int i = 0; i < 3; i++) {
            in_ptrs[i + 1] = in_ptrs[i] + nb;
        }

        for (int64_t x = 0; x < nb; x++) {
            *out_ptr_B = make_block_q8_0x4(in_ptrs, 8); // block_len=8 for SMMLA
            out_ptr_B++;

            for (int i = 0; i < 4; i++) {
                in_ptrs[i]++;
            }
        }
    }
    cur->rearranged_weight_gemm = (uint8_t *) out_ptr_B_start;
}

//
// cross-platform UTF-8 file paths
//

@ -2605,6 +2804,10 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
    *s = idx;
}

static void ggml_gemv_q4_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);

static void ggml_gemv_q8_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth);

//
// data types
//

@ -3647,6 +3850,9 @@ static struct ggml_tensor * ggml_new_tensor_impl(
        /*.name                    =*/ { 0 },
        /*.extra                   =*/ NULL,
        ///*.padding               =*/ { 0 },
        /*.rearranged_weight_gemv  =*/ NULL,
        /*.rearranged_weight_gemm  =*/ NULL,
        /*.weight_rearranged       =*/ false,
    };

#ifdef __clang__

@ -12199,7 +12405,32 @@ UseGgmlGemm1:;
                }
            }
        }
    }
#if defined(__ARM_FEATURE_MATMUL_INT8)
    if ((src0->weight_rearranged == true) && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
        for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) {
            quantize_row_q8_0_and_make_block_q8_0x4((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10, 4);
            wdata += row_size * 4;
        }
        for (int64_t i11 = (ne11 / 4) * 4; i11 < ne11; ++i11) {
            from_float_to_vec_dot((float *)((char *) src1->data + i11 * nb11), (void *) wdata, ne10);
            wdata += row_size;
        }
    }
#endif
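    // Worked example (illustrative, not from the patch): with ne11 = 10
    // activation rows, the branch above packs rows 0..7 as two interleaved
    // block_q8_0x4 groups (ne11 / 4 = 2 full groups of 4 rows), and the tail
    // loop quantizes the remaining rows 8 and 9 individually via
    // from_float_to_vec_dot.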
#if defined(__ARM_FEATURE_MATMUL_INT8)
    else {
#endif
    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            for (int64_t i11 = 0; i11 < ne11; ++i11) {
                from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                wdata += row_size;
            }
        }
    }
#if defined(__ARM_FEATURE_MATMUL_INT8)
    }
#endif

    if (ith == 0) {
        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
@ -12275,25 +12506,141 @@ UseGgmlGemm2:;
|
|||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
//if (ith == 0)
|
||||
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
|
||||
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
const int64_t ir0_start = dr0 * ith0;
|
||||
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
||||
|
||||
const int64_t ir1_start = dr1 * ith1;
|
||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||
|
||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
if (nth >= nchunk0 * nchunk1) {
|
||||
break;
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8) && (defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE))
    if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (src0->weight_rearranged == true)) {
        if (src0->type == GGML_TYPE_Q4_0) {
            ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data), (const char *) src0->rearranged_weight_gemv, (const char *) wdata, ith, nth); // use Arm Neon/SVE GEMV kernels
        } else if (src0->type == GGML_TYPE_Q8_0) {
            ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data), (const char *) src0->rearranged_weight_gemv, (const char *) wdata, ith, nth); // use Arm Neon/SVE GEMV kernels
        }

        current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
    }
    else if ((ggml_n_dims(src0) == 2) && (ne11 >= 16) && (src0->weight_rearranged == true)) {
        // use batch-sized 16, 8, and 4 GEMM kernels
        if (src0->type == GGML_TYPE_Q4_0) {
            for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
                ggml_gemm_q4_0_q8_0(ne00, 16, ne01, 16, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), ith, nth);
            }
            int rows_processed = (ne11 / 16) * 16;
            for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
                ggml_gemm_q4_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->rearranged_weight_gemm,
                                    (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), ith, nth);
            }
            rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
            for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
                ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
                                    (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
            }
            rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
            for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
                ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        } else if (src0->type == GGML_TYPE_Q8_0) {
            for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
                ggml_gemm_q8_0_q8_0(ne00, 16, ne01, 16, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), ith, nth);
            }
            int rows_processed = (ne11 / 16) * 16;
            for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
                ggml_gemm_q8_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->rearranged_weight_gemm,
                                    (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), ith, nth);
            }
            rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
            for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
                ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
                                    (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
            }
            rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
            for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
                ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        }
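
        // Worked example (illustrative, not from the patch): for ne11 = 23
        // activation rows in this branch, rows 0..15 go through one 16-row
        // GEMM call, no 8-row call fires ((23 - 16) / 8 == 0), rows 16..19 go
        // through one 4-row GEMM call, and the remaining rows 20..22 fall
        // back to per-row GEMV calls.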
    } else if ((ggml_n_dims(src0) == 2) && (ne11 >= 8) && (src0->weight_rearranged == true)) {
        // use batch-sized 8, and 4 GEMM kernels
        if (src0->type == GGML_TYPE_Q4_0) {
            for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) {
                ggml_gemm_q4_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), ith, nth);
            }
            int rows_processed = (ne11 / 8) * 8;
            for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
                ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
                                    (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
            }
            for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
                ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        } else if (src0->type == GGML_TYPE_Q8_0) {
            for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) {
                ggml_gemm_q8_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), ith, nth);
            }
            int rows_processed = (ne11 / 8) * 8;
            for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
                ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm,
                                    (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth);
            }
            for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
                ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        }
    } else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (src0->weight_rearranged == true)) {
        // use batch-sized 4 GEMM kernel
        if (src0->type == GGML_TYPE_Q4_0) {
            for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
                ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), ith, nth);
            }
            for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) {
                ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        } else if (src0->type == GGML_TYPE_Q8_0) {
            for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
                ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), ith, nth);
            }
            for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) {
                ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        }
    }
#elif defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
    if ((ggml_n_dims(src0) == 2) && (src0->weight_rearranged == true)) {
        if (src0->type == GGML_TYPE_Q4_0) {
            for (int row_iter = 0; row_iter < ne11; row_iter++) {
                ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        } else if (src0->type == GGML_TYPE_Q8_0) {
            for (int row_iter = 0; row_iter < ne11; row_iter++) {
                ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter) * row_size : (row_iter * nb11)), ith, nth);
            }
        }
    }
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
    else {
#endif
        // The first chunk comes from our thread_id, the rest will get auto-assigned.
        int current_chunk = ith;

        while (current_chunk < nchunk0 * nchunk1) {
            const int64_t ith0 = current_chunk % nchunk0;
            const int64_t ith1 = current_chunk / nchunk0;

            const int64_t ir0_start = dr0 * ith0;
            const int64_t ir0_end = MIN(ir0_start + dr0, nr0);

            const int64_t ir1_start = dr1 * ith1;
            const int64_t ir1_end = MIN(ir1_start + dr1, nr1);

            ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);

            if (nth >= nchunk0 * nchunk1) {
                break;
            }

            current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
        }
#if defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
    }
#endif
}

// ggml_compute_forward_mul_mat_id

@ -21891,4 +22238,26 @@ int ggml_cpu_has_matmul_int8(void) {
#endif
}

#if defined(__ARM_FEATURE_SVE)
static void (*_ggml_gemv_q4_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q4_0_q8_0_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_ggml_gemv_q4_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q4_0_q8_0_blocked8_neon;
#endif

#if defined(__ARM_FEATURE_SVE)
static void (*_ggml_gemv_q8_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q8_0_q8_0_blocked8_sve;
#elif defined(__ARM_NEON)
static void (*_ggml_gemv_q8_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q8_0_q8_0_blocked8_neon;
#endif

#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
static void ggml_gemv_q4_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
    _ggml_gemv_q4_0_q8_0(n, output_channels, input_width, s, vx, vy, ith, nth);
}

static void ggml_gemv_q8_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
    _ggml_gemv_q8_0_q8_0(n, output_channels, input_width, s, vx, vy, ith, nth);
}
#endif

////////////////////////////////////////////////////////////////////////////////

@ -1,3 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
#define LLAMA_API_INTERNAL
#include "llama.h"

@ -4358,6 +4359,32 @@ struct llama_model_loader {
            }
        }

#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
        if ((cur->type == GGML_TYPE_Q4_0) && (cur->ne[1] % 4 == 0)) {
            cur->weight_rearranged = true;
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
            rearrange_q4_0_weights_for_gemv(cur); // rearrange weights for Arm Neon/SVE GEMV kernels
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
            rearrange_q4_0_weights_for_gemm(cur); // rearrange weights for GEMM MMLA kernels
#endif
        }
        else if ((cur->type == GGML_TYPE_Q8_0) && (cur->ne[1] % 4 == 0)) {
            cur->weight_rearranged = true;
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)
            rearrange_q8_0_weights_for_gemv(cur); // rearrange weights for Arm Neon/SVE GEMV kernels
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
            rearrange_q8_0_weights_for_gemm(cur); // rearrange weights for GEMM MMLA kernels
#endif
        }
        else {
            cur->weight_rearranged = false;
        }
#else
        cur->weight_rearranged = false;
#endif
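
        // Example (illustrative, not from the patch): a Q4_0 weight tensor with
        // cur->ne[1] = 4096 rows satisfies ne[1] % 4 == 0, so the loader marks it
        // weight_rearranged and builds the GEMV-blocked copy (and, on builds with
        // __ARM_FEATURE_MATMUL_INT8, the GEMM-blocked copy) at load time.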

        size_done += n_size;
    }
