ggml : SOTA 2-bit quants (add IQ2_XS) (#4856)

* iq2_xs: basics

* iq2_xs: this should have been in the basics

* iq2_xs: CUDA and scalar CPU works

* iq2_xs: WIP Metal

* iq2_xs: Metal now works

* iq2_xs: working, but dog slow, ARM_NEON dot product

* iq2_xs: better ARM_NEON dot product

We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when
running on the CPU.

* iq2_xs: AVX2 dot product - 19.5 t/s

* iq2_xs: faster AVX2 dit product

21.4 t/s for TG-128, 59.2 t/s for PP-512.
The latter is 2x compared to the previous version.

* iq2_xs: had forgotten to delete iq2-data.h

* Add llama enum for IQ2_XS

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2024-01-11 20:39:39 +01:00 committed by GitHub
parent 3ba5b8ca8e
commit 49662cbed3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1038 additions and 28 deletions

View file

@ -2452,6 +2452,13 @@ typedef struct {
} block_iq2_xxs;
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
typedef struct {
half d;
uint16_t qs[QK_K/8];
uint8_t scales[QK_K/32];
} block_iq2_xs;
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
//====================================== dot products =========================
void kernel_mul_mv_q2_K_f32_impl(
@ -3476,7 +3483,7 @@ kernel void kernel_mul_mv_q6_K_f32(
// ======================= "True" 2-bit
constexpr constant static uint64_t kgrid_iq2xxs[256] = {
constexpr constant static uint64_t iq2xxs_grid[256] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
@ -3543,6 +3550,137 @@ constexpr constant static uint64_t kgrid_iq2xxs[256] = {
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
constexpr constant static uint64_t iq2xs_grid[512] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
constexpr constant static uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@ -3600,7 +3738,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
{
int nval = 4;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
@ -3689,6 +3827,149 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq2_xs_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne10,
constant int64_t & ne12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
#if QK_K == 256
const int ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
const int ibl = ib32 / (QK_K / 32);
const int ib = ib32 % (QK_K / 32);
device const block_iq2_xs * xr = x + ibl;
device const uint16_t * q2 = xr->qs + 4 * ib;
device const uint8_t * sc = xr->scales + ib;
device const half * dh = &xr->d;
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
const uint8_t ls1 = sc[0] & 0xf;
const uint8_t ls2 = sc[0] >> 4;
const float d1 = db * (0.5f + ls1);
const float d2 = db * (0.5f + ls2);
float sum1 = 0, sum2 = 0;
for (int l = 0; l < 2; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
const uint8_t signs = shared_signs[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
for (int l = 2; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
const uint8_t signs = shared_signs[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
sumf[row] += d1 * sum1 + d2 * sum2;
dh += nb*sizeof(block_iq2_xs)/2;
q2 += nb*sizeof(block_iq2_xs)/2;
sc += nb*sizeof(block_iq2_xs);
}
y4 += 32 * 32;
}
#else
// TODO
#endif
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
}
}
}
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
@ -3973,18 +4254,39 @@ void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
@ -4525,6 +4827,7 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
//
// matrix-matrix multiplication
@ -4562,6 +4865,7 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
//
// indirect matrix-matrix multiplication
@ -4611,6 +4915,7 @@ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mu
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
//
// matrix-vector multiplication
@ -5448,3 +5753,68 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
tiisg,
sgitg);
}
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
kernel void kernel_mul_mv_id_iq2_xs_f32(
device const char * ids,
device const char * src1,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
device const char * src00,
device const char * src01,
device const char * src02,
device const char * src03,
device const char * src04,
device const char * src05,
device const char * src06,
device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
kernel_mul_mv_iq2_xs_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
ne01,
ne02,
ne10,
ne12,
ne0,
ne1,
r2,
r3,
shared_values,
tgpig,
tiisg,
sgitg);
}