QK_K = 64 tests pass on ARM_NEON and Metal
Sadly, that does not mean it actually works.
This commit is contained in:
parent
28e6146c11
commit
de64e061da
2 changed files with 31 additions and 29 deletions
|
@ -2560,12 +2560,16 @@ typedef struct {
|
||||||
uint8_t qs[QK4_NL/2];
|
uint8_t qs[QK4_NL/2];
|
||||||
} block_iq4_nl;
|
} block_iq4_nl;
|
||||||
|
|
||||||
|
#if QK_K == 64
|
||||||
|
#define block_iq4_xs block_iq4_nl
|
||||||
|
#else
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d;
|
half d;
|
||||||
uint16_t scales_h;
|
uint16_t scales_h;
|
||||||
uint8_t scales_l[QK_K/64];
|
uint8_t scales_l[QK_K/64];
|
||||||
uint8_t qs[QK_K/2];
|
uint8_t qs[QK_K/2];
|
||||||
} block_iq4_xs;
|
} block_iq4_xs;
|
||||||
|
#endif
|
||||||
|
|
||||||
//====================================== dot products =========================
|
//====================================== dot products =========================
|
||||||
|
|
||||||
|
@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
|
||||||
const int ix = tiisg;
|
const int ix = tiisg;
|
||||||
|
|
||||||
device const float * y4 = y + 32 * ix;
|
device const float * y4 = y + 32 * ix;
|
||||||
|
@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
|
|
||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
(void) x;
|
|
||||||
(void) y;
|
|
||||||
(void) yl;
|
|
||||||
(void) nb32;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
|
||||||
const int ix = tiisg;
|
const int ix = tiisg;
|
||||||
|
|
||||||
device const float * y4 = y + 32 * ix;
|
device const float * y4 = y + 32 * ix;
|
||||||
|
@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||||
|
|
||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
(void) x;
|
|
||||||
(void) y;
|
|
||||||
(void) yl;
|
|
||||||
(void) nb32;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
|
||||||
const int ix = tiisg;
|
const int ix = tiisg;
|
||||||
|
|
||||||
device const float * y4 = y + 32 * ix;
|
device const float * y4 = y + 32 * ix;
|
||||||
|
@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||||
|
|
||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
(void) x;
|
|
||||||
(void) y;
|
|
||||||
(void) yl;
|
|
||||||
(void) nb32;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||||
|
|
||||||
const int nb32 = nb * (QK_K / 32);
|
const int nb32 = nb * (QK_K / 32);
|
||||||
|
|
||||||
#if QK_K == 256
|
|
||||||
const int ix = tiisg/2;
|
const int ix = tiisg/2;
|
||||||
const int il = tiisg%2;
|
const int il = tiisg%2;
|
||||||
|
|
||||||
|
@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||||
|
|
||||||
y4 += 16 * 32;
|
y4 += 16 * 32;
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
(void) x;
|
|
||||||
(void) y;
|
|
||||||
(void) yl;
|
|
||||||
(void) nb32;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if QK_K != 64
|
||||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
||||||
kernel void kernel_mul_mv_iq1_s_f32(
|
kernel void kernel_mul_mv_iq1_s_f32(
|
||||||
|
@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
|
#if QK_K == 64
|
||||||
|
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||||
|
#else
|
||||||
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
//============================= templates and their specializations =============================
|
//============================= templates and their specializations =============================
|
||||||
|
@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
||||||
|
#if QK_K == 64
|
||||||
|
dequantize_iq4_nl(xb, il, reg);
|
||||||
|
#else
|
||||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||||
const int ib32 = il/2;
|
const int ib32 = il/2;
|
||||||
il = il%2;
|
il = il%2;
|
||||||
|
@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
||||||
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
||||||
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||||
|
@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
|
||||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||||
|
#if QK_K == 64
|
||||||
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
||||||
|
#else
|
||||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||||
|
#endif
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-matrix multiplication
|
// matrix-matrix multiplication
|
||||||
|
@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
|
||||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||||
|
#if QK_K == 64
|
||||||
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
||||||
|
#else
|
||||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||||
|
#endif
|
||||||
|
|
||||||
//
|
//
|
||||||
// indirect matrix-matrix multiplication
|
// indirect matrix-matrix multiplication
|
||||||
|
@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
|
||||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||||
|
#if QK_K == 64
|
||||||
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
||||||
|
#else
|
||||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||||
|
#endif
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-vector multiplication
|
// matrix-vector multiplication
|
||||||
|
@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
||||||
|
|
||||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||||
|
|
||||||
|
#if QK_K == 64
|
||||||
|
kernel_mul_mv_iq4_nl_f32_impl(
|
||||||
|
#else
|
||||||
kernel_mul_mv_iq4_xs_f32_impl(
|
kernel_mul_mv_iq4_xs_f32_impl(
|
||||||
|
#endif
|
||||||
src0[id],
|
src0[id],
|
||||||
(device const float *) (src1 + bid*nb11),
|
(device const float *) (src1 + bid*nb11),
|
||||||
dst + bid*ne0,
|
dst + bid*ne0,
|
||||||
|
|
|
@ -10262,7 +10262,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||||
|
|
||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#if defined __ARM_NEON
|
#if defined __ARM_NEON && QK_K != 64
|
||||||
|
|
||||||
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
||||||
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue