Compare commits

...
Sign in to create a new pull request.

24 commits

Author SHA1 Message Date
Iwan Kawrakow
31cecc8734 iq3_s_mult_shuffle: use lookup table on Metal
~4% faster TG and ~2% faster PP that way.
2024-03-05 10:19:44 +02:00
Iwan Kawrakow
93034df760 iq3_s_mult_shuffle: use lookup table on CUDA
~4% faster TG that way.
2024-03-05 10:06:07 +02:00
Iwan Kawrakow
6d15da1ec0 iq3_s_mult_shuffle: use new multiplier and cleanup 2024-03-05 08:36:57 +02:00
Iwan Kawrakow
b1d753be34 iq3_s_mult: remove SLOW_MULT option 2024-03-05 08:23:37 +02:00
Iwan Kawrakow
a6a263b919 iq3_s_mult_shuffle: works on ARM_NEON and Metal 2024-03-04 20:10:36 +02:00
Iwan Kawrakow
b587482287 iq3_s_mult_shuffle: mult + shuffle based codebook 2024-03-04 19:43:22 +02:00
Iwan Kawrakow
b48bf8b411 iq3_s_mult: scalar dot product 2024-03-04 08:22:03 +02:00
Iwan Kawrakow
f2c2bd6b26 iq3_s_mult: also CUDA 2024-03-03 19:12:05 +02:00
Iwan Kawrakow
e5e72562c5 iq3_s_mult: back to blocks of 32 2024-03-03 18:50:26 +02:00
Iwan Kawrakow
f4cb4eac45 iq3_s_mult: play with blocks of 16
This brings the bpw to 3.5625. We come close but
don't quite match lookup with 3.4375 bpw (blocks of 32)
2024-03-03 16:43:00 +02:00
Iwan Kawrakow
dbe98dfe70 iq3_s_mult: another alternative multiplier 2024-03-03 13:13:52 +02:00
Iwan Kawrakow
8b713a987e iq3s_mult: quantization tuning 2024-03-03 11:32:53 +02:00
Iwan Kawrakow
5b9c8785fa iq3s_mult: ARM and Metal 2024-03-03 11:30:01 +02:00
Iwan Kawrakow
b6402fa757 iq3_s_mult: ifdef'd slow / fast versions 2024-03-03 10:43:53 +02:00
Iwan Kawrakow
726aed307a iq3_s_mult: alternative multiplier / bit twidling 2024-03-03 08:51:28 +02:00
Iwan Kawrakow
fe3c20b251 iq3_s_mult: quantization tuning 2024-03-03 07:51:20 +02:00
Iwan Kawrakow
3000e0ac9e iq3_s_mult: Metal works - slower than lookup 2024-03-03 06:41:58 +02:00
Iwan Kawrakow
bf90920fb2 iq3_s_mult: ARM_NEON works - 13 t/s 2024-03-02 19:17:27 +02:00
Iwan Kawrakow
0fe9cd488f WIP 2024-03-02 17:56:16 +02:00
Iwan Kawrakow
e43e81a5d7 WIP 2024-03-01 18:48:08 +02:00
Iwan Kawrakow
160acecaba iq3_s_multiplier: CUDA and AVX2 works
CUDA is 153.8 t/s, so faster than lookup table (151 t/s) and Q3_K_S (145 t/s).
AVX2 on Ryzen-5975WX is 13.7 t/s, so faster than lookup (12.7 t/s), but
slower than Q3_K_S (15.5 t/s).
2024-03-01 13:44:06 +02:00
Iwan Kawrakow
4c21c826e1 WIP 2024-03-01 13:28:20 +02:00
Iwan Kawrakow
1cc7cb2b46 iq3_s(multiplier): use SIMD also in dequantize 2024-03-01 12:02:39 +02:00
Iwan Kawrakow
9c752ff0d3 Trying IQ3_S without a lookup table 2024-03-01 11:52:17 +02:00
4 changed files with 482 additions and 323 deletions

View file

@ -544,14 +544,15 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define QR3_XS 8
#define QI3_XS (QK_K / (4*QR3_XS))
#define IQ3S_BLOCK_SIZE 32
typedef struct {
half d;
uint8_t qs[QK_K/4];
uint8_t qh[QK_K/32];
uint8_t signs[QK_K/8];
uint8_t scales[QK_K/64];
uint8_t scales[QK_K/(2*IQ3S_BLOCK_SIZE)];
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + QK_K/(2*IQ3S_BLOCK_SIZE), "wrong iq3_s block size/padding");
#define QR1_S 8
#define QI1_S (QK_K / (4*QR1_S))
@ -2008,71 +2009,71 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const __device__ uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
static const __device__ uint32_t iq3s_grid[512] = {
0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307,
0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101,
0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07,
0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701,
0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507,
0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301,
0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107,
0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01,
0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907,
0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501,
0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307,
0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101,
0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07,
0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901,
0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507,
0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301,
0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107,
0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01,
0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907,
0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701,
0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507,
0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301,
0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107,
0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01,
0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707,
0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501,
0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307,
0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101,
0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07,
0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901,
0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507,
0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301,
0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107,
0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01,
0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907,
0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701,
0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507,
0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301,
0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107,
0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01,
0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707,
0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501,
0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307,
0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101,
0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07,
0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701,
0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507,
0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301,
0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107,
0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01,
0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907,
0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501,
0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307,
0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101,
0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07,
0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901,
0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707,
0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501,
0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307,
0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101,
0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07,
0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701,
0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507,
0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301,
};
@ -2370,6 +2371,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
}
// On CUDA it is fuster to use a lookup table instead of directly computing using these
//#define IQ3S_MULTIPLIER 518559
//static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@ -2382,14 +2387,22 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
//int32_t aux32[2];
//const int8_t * grid = (const int8_t *)aux32;
const int is = (32*ib + 8*il)/IQ3S_BLOCK_SIZE;
const float d = (float)x[i].d * (1 + 2*((x[i].scales[is/2] >> 4*(is%2)) & 0xf));
const uint8_t signs = x[i].signs[4*ib + il];
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
//aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
//aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
//for (int j = 0; j < 8; ++j) {
// y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
//}
#else
assert(false);
#endif
@ -5189,30 +5202,50 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#endif
}
// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
//uint32_t aux32[2];
//uint8_t * aux8 = (uint8_t *)aux32;
const int ib32 = iqs;
const uint8_t * qs = bq2->qs + 8*ib32;
const int8_t * q8 = bq8_1[ib32].qs;
#if IQ3S_BLOCK_SIZE == 32
int sumi = 0;
#else
int sumi[2] = {0, 0};
#endif
for (int l = 0; l < 4; ++l) {
const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
//aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
//aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
//for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]];
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
//const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);
//const int grid_h = __vsub4(aux32[1] ^ signs1, signs1);
const int grid_l = __vsub4(iq3s_grid[qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)] ^ signs0, signs0);
const int grid_h = __vsub4(iq3s_grid[qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)] ^ signs1, signs1);
#if IQ3S_BLOCK_SIZE == 32
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
#else
sumi[l/2] = __dp4a(grid_l, *((int *)q8+0), sumi[l/2]);
sumi[l/2] = __dp4a(grid_h, *((int *)q8+1), sumi[l/2]);
#endif
q8 += 8;
}
const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
#if IQ3S_BLOCK_SIZE == 32
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
return d * sumi;
#else
int ls1 = 1 + 2*(bq2->scales[ib32] & 0xf);
int ls2 = 1 + 2*(bq2->scales[ib32] >> 4);
return (float)bq2->d * __low2float(bq8_1[ib32].ds) * (ls1 * sumi[0] + ls2 * sumi[1]);
#endif
#else
assert(false);
return 0.f;

View file

@ -2547,6 +2547,10 @@ typedef struct {
uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
// When a shuffle is involved in the codebook, on Metal it is faster to use a lookup table
//#define IQ3S_MULTIPLIER 518559
//constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
typedef struct {
half d;
uint8_t qs[QK_K/8];
@ -4083,71 +4087,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
constexpr constant static uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
constexpr constant static uint32_t iq3s_grid[512] = {
0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307,
0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101,
0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07,
0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701,
0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507,
0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301,
0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107,
0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01,
0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907,
0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501,
0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307,
0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101,
0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07,
0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901,
0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507,
0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301,
0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107,
0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01,
0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907,
0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701,
0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507,
0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301,
0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107,
0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01,
0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707,
0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501,
0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307,
0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101,
0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07,
0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901,
0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507,
0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301,
0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107,
0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01,
0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907,
0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701,
0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507,
0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301,
0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107,
0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01,
0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707,
0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501,
0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307,
0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101,
0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07,
0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701,
0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507,
0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301,
0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107,
0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01,
0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907,
0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501,
0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307,
0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101,
0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07,
0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901,
0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707,
0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501,
0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307,
0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101,
0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07,
0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701,
0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507,
0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301,
};
#define NGRID_IQ1S 512
@ -4759,12 +4763,24 @@ void kernel_mul_mv_iq3_s_f32_impl(
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
for (int i = 0; i < nval; ++i) {
values[pos + i] = iq3s_grid[pos + i];
}
//uint32_t aux32;
//thread int8_t * q = (thread int8_t *)&aux32;
//for (int i = 0; i < nval; ++i) {
// aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
// for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]];
// values[pos + i] = aux32;
//}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
const int ix = tiisg;
//uint32_t aux32[2];
//thread const int8_t * grid = (thread const int8_t *)aux32;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
@ -4786,12 +4802,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
// This is slower than pre-computing the grid in shared memory and loading from there
//aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
//aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
//for (int j = 0; j < 4; ++j) {
// sum[0] += yl[8*l + j + 0] * iq3s_values[grid[j+0]] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
// sum[1] += yl[8*l + j + 4] * iq3s_values[grid[j+4]] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
//}
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
@ -4812,7 +4835,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
@ -5703,19 +5726,33 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
//uint32_t aux32[2];
//thread const int8_t * grid = (thread const int8_t *)aux32;
//aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
//aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
//for (int i = 0; i < 4; ++i) {
// reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
// reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
//}
//aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f;
//aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f;
//for (int i = 0; i < 4; ++i) {
// reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
// reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
//}
}
template <typename type4x4>

View file

@ -3789,73 +3789,6 @@ static const uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
#define NGRID_IQ2XXS 512
static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@ -4121,10 +4054,20 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
// ====================== 3.3125 bpw (de)-quantization
#define IQ3S_MULTIPLIER 518559
#define IQ3S_BITS 3
static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
uint32_t aux32[2];
const int8_t * grid = (const int8_t *)aux32;
float db[64/IQ3S_BLOCK_SIZE];
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
@ -4133,25 +4076,32 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
const uint8_t * signs = x[i].signs;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f;
const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f;
#if IQ3S_BLOCK_SIZE == 16
db[0] = d * (1 + 2*(x[i].scales[ib32+0] & 0xf));
db[1] = d * (1 + 2*(x[i].scales[ib32+0] >> 4));
db[2] = d * (1 + 2*(x[i].scales[ib32+1] & 0xf));
db[3] = d * (1 + 2*(x[i].scales[ib32+1] >> 4));
#else
db[0] = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
db[1] = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
#endif
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
const float dl = db[8*l/IQ3S_BLOCK_SIZE];
aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
for (int j = 0; j < 8; ++j) {
y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
qs += 8;
signs += 4;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
const float dl = db[(8*l+32)/IQ3S_BLOCK_SIZE];
aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
for (int j = 0; j < 8; ++j) {
y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
@ -9982,6 +9932,15 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
#endif
}
#ifdef __AVX2__
static inline __m256i shift_left_epi16(__m256i a, __m256i count) {
const __m256i mask = _mm256_set1_epi32(0xffff0000);
const __m256i lo_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count));
const __m256i hi_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16));
return _mm256_blend_epi16(lo_half, hi_half, 0xaa);
}
#endif
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
@ -9990,6 +9949,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
UNUSED(by);
UNUSED(bs);
GGML_ASSERT(IQ3S_BLOCK_SIZE == 32 && "IQ3S_BLOCK_SIZE != 32 is not implemented");
const block_iq3_s * restrict x = vx;
const block_q8_K * restrict y = vy;
@ -10003,8 +9964,17 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
const uint8x16_t mask2 = vld1q_u8(k_mask2);
const int8x16_t shuff = vld1q_s8((const int8_t *)iq3s_values);
const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER);
const int16x8_t idx_shift = vld1q_s16(k_shift);
const uint16x8_t idx_mask1 = vdupq_n_u16(256);
const uint32x4_t idx_mask2 = vdupq_n_u32(0x0f0f0f0f);
const int8x16_t m1 = vdupq_n_s8(1);
uint8x16x2_t vs;
ggml_int8x16x4_t q3s;
@ -10020,44 +9990,44 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
int sumi1 = 0, sumi2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
qs += 16;
const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
const uint16x8_t idx_1 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), idx_shift), idx_mask1),
vmovl_u8(vget_low_u8(idx_l)));
const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1),
vmovl_u8(vget_high_u8(idx_l)));
q3s.val[0] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)));
q3s.val[1] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)));
q3s.val[2] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)));
q3s.val[3] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)));
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[0]);
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[1]);
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
signs += 4;
q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]);
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]);
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
signs += 4;
}
sumf += d*(sumi1 + sumi2);
}
*s = 0.25f * sumf;
*s = sumf;
#elif defined(__AVX2__)
@ -10071,6 +10041,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
const __m128i shuffle128 = _mm_loadu_si128((const __m128i *)iq3s_values);
const __m256i shuffle = _mm256_set_m128i(shuffle128, shuffle128);
const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER);
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
const __m256i m100 = _mm256_set1_epi32(0x0100);
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
@ -10084,24 +10060,22 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
qs += 8;
const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
qs += 8;
const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32(
_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)),
_mm256_setzero_si256());
const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32(
_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)),
_mm256_setzero_si256());
const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1);
const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2);
const __m256i q2_1 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15));
const __m256i q2_2 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15));
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
@ -10129,10 +10103,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
}
*s = 0.25f * hsum_float_8(accumf);
*s = hsum_float_8(accumf);
#else
uint32_t aux32[2];
const uint8_t * grid = (const uint8_t *)aux32;
float sumf = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
@ -10146,11 +10123,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
int32_t sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
sumi += grid[j] * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
@ -10159,11 +10135,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
bsum += sumi * ls1;
sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))) & 0x0f0f0f0f;
aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))) & 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
sumi += q8[j] * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
@ -10173,7 +10148,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
}
sumf += d * bsum;
}
*s = 0.25f * sumf;
*s = sumf;
#endif
}
@ -11311,6 +11286,111 @@ static int iq3_compare_func(const void * left, const void * right) {
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
}
static void iq3xs_init_grid512(void) {
const int kmap_size = 1 << IQ3S_BITS*4;
const int grid_size = 512;
const int nwant = 3;
const int gindex = iq3_data_index(512);
const uint8_t kmask = (1 << IQ3S_BITS) - 1;
uint32_t * kgrid_q3xs;
int * kmap_q3xs;
uint16_t * kneighbors_q3xs;
printf("================================================================= %s(grid_size = %d, map_size = %d)\n",
__func__, grid_size, kmap_size);
uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
uint32_t aux32;
const uint8_t * q4 = (const uint8_t *)&aux32;
for (int k = 0; k < grid_size; ++k) {
int8_t * pos = (int8_t *)(the_grid + k);
aux32 = ((k * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
for (int i = 0; i < 4; ++i) {
pos[i] = iq3s_values[q4[i]];
}
}
kgrid_q3xs = the_grid;
iq3_data[gindex].grid = the_grid;
kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
iq3_data[gindex].map = kmap_q3xs;
for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
for (int i = 0; i < grid_size; ++i) {
aux32 = kgrid_q3xs[i];
uint16_t index = 0;
for (int k=0; k<4; ++k) {
uint16_t q = (q4[k] - 1)/2;
index |= (q << IQ3S_BITS*k);
}
kmap_q3xs[index] = i;
}
int8_t pos[4];
int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
int num_neighbors = 0, num_not_in_map = 0;
for (int i = 0; i < kmap_size; ++i) {
if (kmap_q3xs[i] >= 0) continue;
++num_not_in_map;
for (int k = 0; k < 4; ++k) {
int l = (i >> IQ3S_BITS*k) & kmask;
pos[k] = 2*l + 1;
}
for (int j = 0; j < grid_size; ++j) {
const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
int d2 = 0;
for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
dist2[2*j+0] = d2;
dist2[2*j+1] = j;
}
qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
int n = 0; int d2 = dist2[0];
int nhave = 1;
for (int j = 0; j < grid_size; ++j) {
if (dist2[2*j] > d2) {
if (nhave == nwant) break;
d2 = dist2[2*j];
++nhave;
}
++n;
}
num_neighbors += n;
}
printf("%s: %d neighbours in total\n", __func__, num_neighbors);
kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
iq3_data[gindex].neighbours = kneighbors_q3xs;
int counter = 0;
for (int i = 0; i < kmap_size; ++i) {
if (kmap_q3xs[i] >= 0) continue;
for (int k = 0; k < 4; ++k) {
int l = (i >> IQ3S_BITS*k) & kmask;
pos[k] = 2*l + 1;
}
for (int j = 0; j < grid_size; ++j) {
const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
int d2 = 0;
for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
dist2[2*j+0] = d2;
dist2[2*j+1] = j;
}
qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
kmap_q3xs[i] = -(counter + 1);
int d2 = dist2[0];
uint16_t * start = &kneighbors_q3xs[counter++];
int n = 0, nhave = 1;
for (int j = 0; j < grid_size; ++j) {
if (dist2[2*j] > d2) {
if (nhave == nwant) break;
d2 = dist2[2*j];
++nhave;
}
kneighbors_q3xs[counter++] = dist2[2*j+1];
++n;
}
*start = n;
}
free(dist2);
}
void iq3xs_init_impl(int grid_size) {
const int gindex = iq3_data_index(grid_size);
if (iq3_data[gindex].grid) {
@ -11334,44 +11414,49 @@ void iq3xs_init_impl(int grid_size) {
3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
};
static const uint16_t kgrid_512[512] = {
0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34,
37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77,
80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142,
145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210,
217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288,
291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393,
395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514,
516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576,
577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653,
655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727,
728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833,
840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977,
989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047,
1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103,
1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199,
1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296,
1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415,
1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561,
1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648,
1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761,
1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877,
1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068,
2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177,
2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269,
2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520,
2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634,
2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805,
2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083,
3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276,
3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591,
3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729,
3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032,
};
if (grid_size == 512) {
iq3xs_init_grid512();
return;
}
// static const uint16_t kgrid_512[512] = {
// 0, 170, 356, 23, 137, 324, 54, 225, 411, 78, 192, 435, 109, 280, 466, 133,
// 312, 42, 164, 343, 9, 708, 886, 545, 723, 910, 576, 819, 933, 600, 778, 965,
// 688, 874, 1052, 1175, 1345, 1084, 1262, 1441, 1107, 1230, 1408, 1139, 1317, 1496, 1162, 1285,
//1584, 1698, 1884, 1551, 1729, 1908, 1582, 1753, 1939, 1606, 1848, 1963, 1637, 1808, 2506, 2181,
//2416, 2082, 2204, 2383, 2049, 2292, 2470, 2073, 2251, 2438, 2160, 2859, 3037, 2704, 2818, 2621,
//2728, 2914, 2580, 2767, 2881, 2612, 2790, 2969, 2635, 3334, 3504, 3179, 3357, 3536, 3202, 3445,
//3112, 3226, 3412, 3079, 3321, 3500, 3622, 3793, 3979, 3654, 3888, 4067, 3677, 264, 2, 181,
// 360, 26, 212, 327, 57, 236, 414, 81, 259, 446, 104, 291, 469, 136, 322, 53,
// 160, 346, 524, 711, 945, 556, 734, 913, 579, 830, 1000, 611, 789, 512, 698, 1389,
//1056, 1170, 1356, 1031, 1265, 1444, 1118, 1289, 1411, 1142, 1320, 1499, 1173, 1856, 1594, 1709,
//1888, 1554, 1740, 1983, 1577, 1764, 1942, 1609, 1795, 2038, 2144, 2331, 2061, 2176, 2418, 2093,
//2200, 2386, 2052, 2303, 2473, 2148, 2262, 2441, 2627, 2870, 3040, 2707, 2893, 2560, 2738, 2917,
//2584, 2762, 2948, 2615, 2793, 2972, 3158, 3329, 3579, 3182, 3360, 3091, 3213, 3392, 3122, 3237,
//3416, 3082, 3268, 4023, 3681, 3804, 3982, 3649, 3891, 4078, 152, 275, 5, 184, 362, 37,
// 208, 394, 4, 247, 417, 92, 270, 449, 115, 294, 24, 139, 325, 48, 682, 861,
// 528, 706, 956, 623, 737, 916, 590, 769, 1011, 678, 792, 523, 1157, 1392, 1066, 1245,
//1360, 1026, 1268, 1455, 1113, 1300, 1478, 1145, 1323, 1574, 1680, 1867, 1541, 1712, 1890, 1565,
//1736, 1922, 1652, 1775, 1945, 1620, 1798, 2553, 2219, 2334, 2064, 2179, 2429, 2088, 2274, 2389,
//2056, 2242, 2484, 2151, 2329, 2956, 2630, 2865, 3051, 2718, 2896, 2563, 2749, 2920, 2594, 2773,
//2944, 2682, 3308, 3495, 3153, 3340, 3526, 3249, 3363, 3102, 3208, 3395, 3125, 3304, 3418, 3093,
//3776, 4026, 3692, 3879, 3985, 3660, 3902, 489, 163, 342, 8, 131, 373, 32, 218, 397,
// 64, 242, 428, 95, 273, 452, 190, 297, 35, 150, 328, 515, 757, 864, 530, 717,
// 896, 626, 804, 927, 585, 772, 1014, 681, 859, 1046, 1152, 1403, 1069, 1248, 1426, 1037,
//1216, 1458, 1124, 1311, 1481, 1156, 1846, 1569, 1691, 1870, 1536, 1779, 1901, 1560, 1746, 1925,
//1656, 1834, 1956, 1623, 2313, 2500, 2230, 2401, 2075, 2190, 2368, 2099, 2277, 2456, 2058, 2245,
//2480, 2666, 2844, 3031, 2625, 2876, 2606, 2721, 2899, 2574, 2752, 2931, 2597, 2776, 2954, 3141,
//3376, 3498, 3164, 3343, 3521, 3252, 3438, 3097, 3219, 3398, 3128, 3307, 3493, 3600, 3786, 3973,
//3696, 3874, 4060, 79, 257, 52, 174, 345, 19, 134, 376, 43, 221, 400, 66, 253,
// 424, 98, 276, 463, 129, 372, 38, 153, 843, 518, 752, 939, 541, 720, 898, 637,
// 808, 994, 596, 775, 569, 1196, 1382, 1041, 1163, 1350, 1072, 1251, 1437, 1096, 1218, 1461,
//1128, 1306, 1492, 1671, 1849, 1580, 1702, 1873, 1547, 1790, 1960, 1571, 1749, 1928, 1602, 1845,
//2016, 2138, 2316, 2055, 2225, 2412, 2078, 2193, 2371, 2110, 2280, 2467, 2133, 2248, 2946, 2677,
// };
const int kmap_size = 4096;
const int nwant = grid_size == 256 ? 2 : 3;
const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
//const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
const uint16_t * kgrid = kgrid_256;
uint32_t * kgrid_q3xs;
int * kmap_q3xs;
uint16_t * kneighbors_q3xs;
@ -11773,6 +11858,8 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
for (int ibl = 0; ibl < nbl; ++ibl) {
//float block_mse = 0;
memset(&y[ibl], 0, sizeof(block_iq3_s));
y[ibl].d = GGML_FP32_TO_FP16(0.f);
@ -11812,9 +11899,10 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
scales[ib] = 0;
continue;
}
for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;
float best = 0;
float scale = max/(2*kMaxQ-1);
for (int is = -15; is <= 15; ++is) {
for (int is = -9; is <= 9; ++is) {
float id = (2*kMaxQ-1+is*0.2f)/max;
float this_scale = 1/id;
for (int k = 0; k < bs4; ++k) {
@ -11823,7 +11911,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
}
uint16_t u = 0;
for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
for (int i = 0; i < 4; ++i) u |= Laux[4*k+i] << IQ3S_BITS*i;
int grid_index = kmap_q3xs[u];
is_on_grid_aux[k] = true;
if (grid_index < 0) {
@ -11850,12 +11938,12 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
if (n_not_ongrid > 0 && scale > 0) {
float id = 1/scale;
for (int k = 0; k < bs4; ++k) {
if (is_on_grid[k]) continue;
//if (is_on_grid[k]) continue;
uint16_t u = 0;
for (int i = 0; i < 4; ++i) {
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
l = MAX(0, MIN(kMaxQ-1, l));
u |= (l << 3*i);
u |= l << IQ3S_BITS*i;
}
int grid_index = kmap_q3xs[u];
if (grid_index < 0) {
@ -11882,7 +11970,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
}
for (int k = 0; k < bs4; ++k) {
uint16_t u = 0;
for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
for (int i = 0; i < 4; ++i) u |= L[4*k+i] << IQ3S_BITS*i;
int grid_index = kmap_q3xs[u];
if (grid_index < 0) {
printf("Oops: found point %u not on grid:", u);
@ -11899,6 +11987,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
GGML_ASSERT(scale >= 0);
scales[ib] = scale;
max_scale = MAX(max_scale, scale);
}
if (!max_scale) {
@ -11906,7 +11995,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
}
float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d);
y[ibl].d = GGML_FP32_TO_FP16(d * 1.030f);
float id = 1/d;
for (int ib = 0; ib < QK_K/block_size; ib += 2) {
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
@ -11919,7 +12008,6 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
}
}
#define IQ3S_BLOCK_SIZE 32
size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
GGML_ASSERT(n_per_row%QK_K == 0);

View file

@ -201,10 +201,11 @@ typedef struct {
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
// 3.4375 bpw
#define IQ3S_BLOCK_SIZE 32
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
#define IQ3S_N_SCALE QK_K/64
#define IQ3S_N_SCALE QK_K/(2*IQ3S_BLOCK_SIZE)
#endif
typedef struct {
ggml_fp16_t d;