ggml : add mmla kernels for quantized GEMM (#4966)

* ggml: aarch64: implement smmla kernel for q8_0_q8_0 quantized gemm

armv8.2-a and above supports MMLA instructions that have higher
throughput than DOT. this commit adds mmla kernel for
q8_0_q8_0 gemm. The feature is enabled if the platform supports
"__ARM_FEATURE_MATMUL_INT8"

On AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.

* ggml: aarch64: implement smmla kernel for q4_0_q8_0 quantized gemm

armv8.2-a and above supports MMLA instructions that have higher
throughput than DOT. this commit adds mmla kernel for
q4_0_q8_0 gemm. The feature is enabled if the platform supports
"__ARM_FEATURE_MATMUL_INT8"

On AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.

* ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm

armv8.2-a and above supports MMLA instructions that have higher
throughput than DOT. this commit adds mmla kernel for
q4_1_q8_1 gemm. The feature is enabled if the platform supports
"__ARM_FEATURE_MATMUL_INT8"

On AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.

* ggml: update unit tests for the new vec_dot interface

* llama.cpp: add MATMUL_INT8 capability to system_info
This commit is contained in:
snadampal 2024-02-11 07:22:33 -06:00 committed by GitHub
parent e4640d8fdf
commit a07d0fee1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 441 additions and 88 deletions

View file

@ -49,6 +49,8 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define UNUSED GGML_UNUSED
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@ -3677,15 +3679,88 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_0 * restrict vx0 = vx;
const block_q4_0 * restrict vx1 = vx + bx;
const block_q8_0 * restrict vy0 = vy;
const block_q8_0 * restrict vy1 = vy + by;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i++) {
const block_q4_0 * restrict b_x0 = &vx0[i];
const block_q4_0 * restrict b_x1 = &vx1[i];
const block_q8_0 * restrict b_y0 = &vy0[i];
const block_q8_0 * restrict b_y1 = &vy1[i];
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
return;
}
#endif
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -3967,15 +4042,89 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
#endif
}
void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1;
const int nb = n / qk;
assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
const block_q4_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_1 * restrict vx0 = vx;
const block_q4_1 * restrict vx1 = vx + bx;
const block_q8_1 * restrict vy0 = vy;
const block_q8_1 * restrict vy1 = vy + by;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t summs0 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i++) {
const block_q4_1 * restrict b_x0 = &vx0[i];
const block_q4_1 * restrict b_x1 = &vx1[i];
const block_q8_1 * restrict b_y0 = &vy0[i];
const block_q8_1 * restrict b_y1 = &vy1[i];
float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
summs0 += summs_t;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
// 4-bit -> 8-bit
const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
// mmla into int32x4_t
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
sumv2 = sumv2 + summs0;
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
return;
}
#endif
// TODO: add WASM SIMD
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -4107,12 +4256,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
#endif
}
void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
assert(qk == QK5_0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
@ -4393,12 +4547,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
#endif
}
void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1;
const int nb = n / qk;
assert(n % qk == 0);
assert(qk == QK5_1);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_1 * restrict x = vx;
const block_q8_1 * restrict y = vy;
@ -4692,15 +4851,75 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
#endif
}
void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
const block_q8_0 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q8_0 * restrict vx0 = vx;
const block_q8_0 * restrict vx1 = vx + bx;
const block_q8_0 * restrict vy0 = vy;
const block_q8_0 * restrict vy1 = vy + by;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i++) {
const block_q8_0 * restrict b_x0 = &vx0[i];
const block_q8_0 * restrict b_y0 = &vy0[i];
const block_q8_0 * restrict b_x1 = &vx1[i];
const block_q8_0 * restrict b_y1 = &vy1[i];
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
return;
}
#endif
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -4795,7 +5014,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
}
#if QK_K == 256
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -5171,7 +5395,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#else
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -5429,8 +5658,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#endif
#if QK_K == 256
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
@ -5949,8 +6183,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
#else
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q3_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -6292,8 +6531,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
#endif
#if QK_K == 256
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -6648,8 +6892,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
#endif
}
#else
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -6891,8 +7140,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
#endif
#if QK_K == 256
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -7311,8 +7565,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#else
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -7577,8 +7836,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#if QK_K == 256
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -8009,8 +8273,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
#else
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -8339,8 +8608,13 @@ static const int8_t keven_signs_q2xs[1024] = {
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
};
void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq2_xxs * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -8462,8 +8736,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
#endif
}
void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq2_xs * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -8682,8 +8961,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
}
// TODO
void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq3_xxs * restrict x = vx;
const block_q8_K * restrict y = vy;