diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dfd28df62..785a7a3f9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -544,14 +544,15 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong #define QR3_XS 8 #define QI3_XS (QK_K / (4*QR3_XS)) +#define IQ3S_BLOCK_SIZE 32 typedef struct { half d; uint8_t qs[QK_K/4]; uint8_t qh[QK_K/32]; uint8_t signs[QK_K/8]; - uint8_t scales[QK_K/64]; + uint8_t scales[QK_K/(2*IQ3S_BLOCK_SIZE)]; } block_iq3_s; -static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding"); +static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + QK_K/(2*IQ3S_BLOCK_SIZE), "wrong iq3_s block size/padding"); #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) @@ -2008,71 +2009,71 @@ static const __device__ uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -static const __device__ uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +static const __device__ uint32_t iq3s_grid[512] = { + 0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307, + 0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101, + 0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07, + 0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701, + 0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507, + 0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301, + 0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107, + 0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01, + 0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907, + 0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501, + 0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307, + 0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101, + 0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07, + 0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901, + 0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507, + 0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301, + 0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107, + 0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01, + 0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907, + 0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701, + 0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507, + 0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301, + 0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107, + 0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01, + 0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707, + 0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501, + 0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307, + 0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101, + 0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07, + 0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901, + 0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507, + 0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301, + 0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107, + 0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01, + 0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907, + 0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701, + 0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507, + 0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301, + 0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107, + 0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01, + 0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707, + 0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501, + 0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307, + 0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101, + 0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07, + 0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701, + 0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507, + 0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301, + 0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107, + 0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01, + 0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907, + 0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501, + 0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307, + 0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101, + 0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07, + 0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901, + 0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707, + 0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501, + 0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307, + 0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101, + 0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07, + 0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701, + 0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507, + 0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301, }; @@ -2370,6 +2371,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } +// On CUDA it is fuster to use a lookup table instead of directly computing using these +//#define IQ3S_MULTIPLIER 518559 +//static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; + template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2382,14 +2387,22 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); - const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; + //int32_t aux32[2]; + //const int8_t * grid = (const int8_t *)aux32; + const int is = (32*ib + 8*il)/IQ3S_BLOCK_SIZE; + const float d = (float)x[i].d * (1 + 2*((x[i].scales[is/2] >> 4*(is%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); for (int j = 0; j < 4; ++j) { y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } + //aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //for (int j = 0; j < 8; ++j) { + // y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + //} #else assert(false); #endif @@ -5189,30 +5202,50 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #endif } -// TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics #if QK_K == 256 const block_iq3_s * bq2 = (const block_iq3_s *) vbq; + //uint32_t aux32[2]; + //uint8_t * aux8 = (uint8_t *)aux32; + const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; const int8_t * q8 = bq8_1[ib32].qs; +#if IQ3S_BLOCK_SIZE == 32 int sumi = 0; +#else + int sumi[2] = {0, 0}; +#endif for (int l = 0; l < 4; ++l) { - const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); - const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + //aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + //for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]]; uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); - const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); + //const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); + //const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); + const int grid_l = __vsub4(iq3s_grid[qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)] ^ signs0, signs0); + const int grid_h = __vsub4(iq3s_grid[qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)] ^ signs1, signs1); +#if IQ3S_BLOCK_SIZE == 32 sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); +#else + sumi[l/2] = __dp4a(grid_l, *((int *)q8+0), sumi[l/2]); + sumi[l/2] = __dp4a(grid_h, *((int *)q8+1), sumi[l/2]); +#endif q8 += 8; } - const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; +#if IQ3S_BLOCK_SIZE == 32 + const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); return d * sumi; +#else + int ls1 = 1 + 2*(bq2->scales[ib32] & 0xf); + int ls2 = 1 + 2*(bq2->scales[ib32] >> 4); + return (float)bq2->d * __low2float(bq8_1[ib32].ds) * (ls1 * sumi[0] + ls2 * sumi[1]); +#endif #else assert(false); return 0.f; diff --git a/ggml-metal.metal b/ggml-metal.metal index 689411903..90fe88884 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2547,6 +2547,10 @@ typedef struct { uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; +// When a shuffle is involved in the codebook, on Metal it is faster to use a lookup table +//#define IQ3S_MULTIPLIER 518559 +//constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; + typedef struct { half d; uint8_t qs[QK_K/8]; @@ -4083,71 +4087,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -constexpr constant static uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +constexpr constant static uint32_t iq3s_grid[512] = { + 0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307, + 0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101, + 0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07, + 0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701, + 0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507, + 0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301, + 0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107, + 0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01, + 0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907, + 0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501, + 0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307, + 0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101, + 0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07, + 0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901, + 0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507, + 0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301, + 0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107, + 0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01, + 0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907, + 0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701, + 0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507, + 0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301, + 0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107, + 0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01, + 0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707, + 0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501, + 0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307, + 0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101, + 0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07, + 0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901, + 0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507, + 0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301, + 0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107, + 0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01, + 0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907, + 0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701, + 0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507, + 0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301, + 0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107, + 0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01, + 0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707, + 0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501, + 0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307, + 0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101, + 0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07, + 0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701, + 0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507, + 0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301, + 0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107, + 0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01, + 0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907, + 0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501, + 0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307, + 0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101, + 0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07, + 0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901, + 0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707, + 0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501, + 0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307, + 0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101, + 0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07, + 0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701, + 0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507, + 0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301, }; #define NGRID_IQ1S 512 @@ -4759,12 +4763,24 @@ void kernel_mul_mv_iq3_s_f32_impl( { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) { + values[pos + i] = iq3s_grid[pos + i]; + } + //uint32_t aux32; + //thread int8_t * q = (thread int8_t *)&aux32; + //for (int i = 0; i < nval; ++i) { + // aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; + // for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]]; + // values[pos + i] = aux32; + //} threadgroup_barrier(mem_flags::mem_threadgroup); } const int ix = tiisg; + //uint32_t aux32[2]; + //thread const int8_t * grid = (thread const int8_t *)aux32; + device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { @@ -4786,12 +4802,19 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; row++) { const float db = dh[0]; - const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf)); + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + // This is slower than pre-computing the grid in shared memory and loading from there + //aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f; + //for (int j = 0; j < 4; ++j) { + // sum[0] += yl[8*l + j + 0] * iq3s_values[grid[j+0]] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + // sum[1] += yl[8*l + j + 4] * iq3s_values[grid[j+4]] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + //} + threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); @@ -4812,7 +4835,7 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -5703,19 +5726,33 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & device const uint8_t * qs = xb->qs + 8*ib32; device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256))); + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); for (int i = 0; i < 4; ++i) { reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256))); + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); for (int i = 0; i < 4; ++i) { reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } + //uint32_t aux32[2]; + //thread const int8_t * grid = (thread const int8_t *)aux32; + //aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; + //for (int i = 0; i < 4; ++i) { + // reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + // reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + //} + //aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; + //aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; + //for (int i = 0; i < 4; ++i) { + // reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + // reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + //} } template diff --git a/ggml-quants.c b/ggml-quants.c index f73d17ce2..c83b7c775 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3789,73 +3789,6 @@ static const uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -static const uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, -}; - #define NGRID_IQ2XXS 512 static const uint64_t iq1s_grid[NGRID_IQ2XXS] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, @@ -4121,10 +4054,20 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization +#define IQ3S_MULTIPLIER 518559 +#define IQ3S_BITS 3 + +static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; + void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; + uint32_t aux32[2]; + const int8_t * grid = (const int8_t *)aux32; + + float db[64/IQ3S_BLOCK_SIZE]; + for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); @@ -4133,25 +4076,32 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in const uint8_t * signs = x[i].signs; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f; - const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f; +#if IQ3S_BLOCK_SIZE == 16 + db[0] = d * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + db[1] = d * (1 + 2*(x[i].scales[ib32+0] >> 4)); + db[2] = d * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + db[3] = d * (1 + 2*(x[i].scales[ib32+1] >> 4)); +#else + db[0] = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + db[1] = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#endif for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + const float dl = db[8*l/IQ3S_BLOCK_SIZE]; + aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + for (int j = 0; j < 8; ++j) { + y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + const float dl = db[(8*l+32)/IQ3S_BLOCK_SIZE]; + aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + for (int j = 0; j < 8; ++j) { + y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } @@ -9982,6 +9932,15 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void #endif } +#ifdef __AVX2__ +static inline __m256i shift_left_epi16(__m256i a, __m256i count) { + const __m256i mask = _mm256_set1_epi32(0xffff0000); + const __m256i lo_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); + const __m256i hi_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16)); + return _mm256_blend_epi16(lo_half, hi_half, 0xaa); +} +#endif + void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -9990,6 +9949,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(by); UNUSED(bs); + GGML_ASSERT(IQ3S_BLOCK_SIZE == 32 && "IQ3S_BLOCK_SIZE != 32 is not implemented"); + const block_iq3_s * restrict x = vx; const block_q8_K * restrict y = vy; @@ -10003,8 +9964,17 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); const uint8x16_t mask2 = vld1q_u8(k_mask2); + const int8x16_t shuff = vld1q_s8((const int8_t *)iq3s_values); + + const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER); + const int16x8_t idx_shift = vld1q_s16(k_shift); + const uint16x8_t idx_mask1 = vdupq_n_u16(256); + const uint32x4_t idx_mask2 = vdupq_n_u32(0x0f0f0f0f); + const int8x16_t m1 = vdupq_n_s8(1); uint8x16x2_t vs; ggml_int8x16x4_t q3s; @@ -10020,44 +9990,44 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v int sumi1 = 0, sumi2 = 0; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]}; - const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]}; - const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]}; - const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]}; - qs += 16; + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + const uint16x8_t idx_1 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), idx_shift), idx_mask1), + vmovl_u8(vget_low_u8(idx_l))); + const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1), + vmovl_u8(vget_high_u8(idx_l))); + q3s.val[0] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2))); + q3s.val[1] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2))); + q3s.val[2] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2))); + q3s.val[3] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2))); vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1)); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1)); - q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0])); - q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1])); + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[0]); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[1]); vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1)); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1)); - signs += 4; - - q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0])); - q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1])); + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]); const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); + + signs += 4; } sumf += d*(sumi1 + sumi2); } - *s = 0.25f * sumf; + *s = sumf; #elif defined(__AVX2__) @@ -10071,6 +10041,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + const __m128i shuffle128 = _mm_loadu_si128((const __m128i *)iq3s_values); + const __m256i shuffle = _mm256_set_m128i(shuffle128, shuffle128); + + const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); + const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); + const __m256i m100 = _mm256_set1_epi32(0x0100); __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -10084,24 +10060,22 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]); - qs += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]); - qs += 8; + + const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8; + const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8; + uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL; + uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL; + const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32( + _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)), + _mm256_setzero_si256()); + const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32( + _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)), + _mm256_setzero_si256()); + const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1); + const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2); + + const __m256i q2_1 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15)); + const __m256i q2_2 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15)); __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); @@ -10129,10 +10103,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } - *s = 0.25f * hsum_float_8(accumf); + *s = hsum_float_8(accumf); #else + uint32_t aux32[2]; + const uint8_t * grid = (const uint8_t *)aux32; + float sumf = 0.f; for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -10146,11 +10123,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; int32_t sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); } q8 += 8; } @@ -10159,11 +10135,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v bsum += sumi * ls1; sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))) & 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + sumi += q8[j] * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); } q8 += 8; } @@ -10173,7 +10148,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } sumf += d * bsum; } - *s = 0.25f * sumf; + *s = sumf; #endif } @@ -11311,6 +11286,111 @@ static int iq3_compare_func(const void * left, const void * right) { return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; } +static void iq3xs_init_grid512(void) { + + const int kmap_size = 1 << IQ3S_BITS*4; + const int grid_size = 512; + const int nwant = 3; + const int gindex = iq3_data_index(512); + const uint8_t kmask = (1 << IQ3S_BITS) - 1; + + uint32_t * kgrid_q3xs; + int * kmap_q3xs; + uint16_t * kneighbors_q3xs; + + printf("================================================================= %s(grid_size = %d, map_size = %d)\n", + __func__, grid_size, kmap_size); + uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t)); + uint32_t aux32; + const uint8_t * q4 = (const uint8_t *)&aux32; + for (int k = 0; k < grid_size; ++k) { + int8_t * pos = (int8_t *)(the_grid + k); + aux32 = ((k * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + for (int i = 0; i < 4; ++i) { + pos[i] = iq3s_values[q4[i]]; + } + } + + kgrid_q3xs = the_grid; + iq3_data[gindex].grid = the_grid; + kmap_q3xs = (int *)malloc(kmap_size*sizeof(int)); + iq3_data[gindex].map = kmap_q3xs; + for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1; + for (int i = 0; i < grid_size; ++i) { + aux32 = kgrid_q3xs[i]; + uint16_t index = 0; + for (int k=0; k<4; ++k) { + uint16_t q = (q4[k] - 1)/2; + index |= (q << IQ3S_BITS*k); + } + kmap_q3xs[index] = i; + } + int8_t pos[4]; + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + int num_neighbors = 0, num_not_in_map = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> IQ3S_BITS*k) & kmask; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + num_neighbors += n; + } + printf("%s: %d neighbours in total\n", __func__, num_neighbors); + kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); + iq3_data[gindex].neighbours = kneighbors_q3xs; + int counter = 0; + for (int i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> IQ3S_BITS*k) & kmask; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + kmap_q3xs[i] = -(counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); +} + void iq3xs_init_impl(int grid_size) { const int gindex = iq3_data_index(grid_size); if (iq3_data[gindex].grid) { @@ -11334,44 +11414,49 @@ void iq3xs_init_impl(int grid_size) { 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, }; - static const uint16_t kgrid_512[512] = { - 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34, - 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77, - 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142, - 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210, - 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288, - 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393, - 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514, - 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576, - 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653, - 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727, - 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833, - 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977, - 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047, - 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103, - 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199, - 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296, - 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415, - 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561, - 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648, - 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761, - 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877, - 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068, - 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177, - 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269, - 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520, - 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634, - 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805, - 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083, - 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276, - 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591, - 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729, - 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032, - }; + if (grid_size == 512) { + iq3xs_init_grid512(); + return; + } +// static const uint16_t kgrid_512[512] = { +// 0, 170, 356, 23, 137, 324, 54, 225, 411, 78, 192, 435, 109, 280, 466, 133, +// 312, 42, 164, 343, 9, 708, 886, 545, 723, 910, 576, 819, 933, 600, 778, 965, +// 688, 874, 1052, 1175, 1345, 1084, 1262, 1441, 1107, 1230, 1408, 1139, 1317, 1496, 1162, 1285, +//1584, 1698, 1884, 1551, 1729, 1908, 1582, 1753, 1939, 1606, 1848, 1963, 1637, 1808, 2506, 2181, +//2416, 2082, 2204, 2383, 2049, 2292, 2470, 2073, 2251, 2438, 2160, 2859, 3037, 2704, 2818, 2621, +//2728, 2914, 2580, 2767, 2881, 2612, 2790, 2969, 2635, 3334, 3504, 3179, 3357, 3536, 3202, 3445, +//3112, 3226, 3412, 3079, 3321, 3500, 3622, 3793, 3979, 3654, 3888, 4067, 3677, 264, 2, 181, +// 360, 26, 212, 327, 57, 236, 414, 81, 259, 446, 104, 291, 469, 136, 322, 53, +// 160, 346, 524, 711, 945, 556, 734, 913, 579, 830, 1000, 611, 789, 512, 698, 1389, +//1056, 1170, 1356, 1031, 1265, 1444, 1118, 1289, 1411, 1142, 1320, 1499, 1173, 1856, 1594, 1709, +//1888, 1554, 1740, 1983, 1577, 1764, 1942, 1609, 1795, 2038, 2144, 2331, 2061, 2176, 2418, 2093, +//2200, 2386, 2052, 2303, 2473, 2148, 2262, 2441, 2627, 2870, 3040, 2707, 2893, 2560, 2738, 2917, +//2584, 2762, 2948, 2615, 2793, 2972, 3158, 3329, 3579, 3182, 3360, 3091, 3213, 3392, 3122, 3237, +//3416, 3082, 3268, 4023, 3681, 3804, 3982, 3649, 3891, 4078, 152, 275, 5, 184, 362, 37, +// 208, 394, 4, 247, 417, 92, 270, 449, 115, 294, 24, 139, 325, 48, 682, 861, +// 528, 706, 956, 623, 737, 916, 590, 769, 1011, 678, 792, 523, 1157, 1392, 1066, 1245, +//1360, 1026, 1268, 1455, 1113, 1300, 1478, 1145, 1323, 1574, 1680, 1867, 1541, 1712, 1890, 1565, +//1736, 1922, 1652, 1775, 1945, 1620, 1798, 2553, 2219, 2334, 2064, 2179, 2429, 2088, 2274, 2389, +//2056, 2242, 2484, 2151, 2329, 2956, 2630, 2865, 3051, 2718, 2896, 2563, 2749, 2920, 2594, 2773, +//2944, 2682, 3308, 3495, 3153, 3340, 3526, 3249, 3363, 3102, 3208, 3395, 3125, 3304, 3418, 3093, +//3776, 4026, 3692, 3879, 3985, 3660, 3902, 489, 163, 342, 8, 131, 373, 32, 218, 397, +// 64, 242, 428, 95, 273, 452, 190, 297, 35, 150, 328, 515, 757, 864, 530, 717, +// 896, 626, 804, 927, 585, 772, 1014, 681, 859, 1046, 1152, 1403, 1069, 1248, 1426, 1037, +//1216, 1458, 1124, 1311, 1481, 1156, 1846, 1569, 1691, 1870, 1536, 1779, 1901, 1560, 1746, 1925, +//1656, 1834, 1956, 1623, 2313, 2500, 2230, 2401, 2075, 2190, 2368, 2099, 2277, 2456, 2058, 2245, +//2480, 2666, 2844, 3031, 2625, 2876, 2606, 2721, 2899, 2574, 2752, 2931, 2597, 2776, 2954, 3141, +//3376, 3498, 3164, 3343, 3521, 3252, 3438, 3097, 3219, 3398, 3128, 3307, 3493, 3600, 3786, 3973, +//3696, 3874, 4060, 79, 257, 52, 174, 345, 19, 134, 376, 43, 221, 400, 66, 253, +// 424, 98, 276, 463, 129, 372, 38, 153, 843, 518, 752, 939, 541, 720, 898, 637, +// 808, 994, 596, 775, 569, 1196, 1382, 1041, 1163, 1350, 1072, 1251, 1437, 1096, 1218, 1461, +//1128, 1306, 1492, 1671, 1849, 1580, 1702, 1873, 1547, 1790, 1960, 1571, 1749, 1928, 1602, 1845, +//2016, 2138, 2316, 2055, 2225, 2412, 2078, 2193, 2371, 2110, 2280, 2467, 2133, 2248, 2946, 2677, +// }; const int kmap_size = 4096; const int nwant = grid_size == 256 ? 2 : 3; - const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + //const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512; + const uint16_t * kgrid = kgrid_256; uint32_t * kgrid_q3xs; int * kmap_q3xs; uint16_t * kneighbors_q3xs; @@ -11773,6 +11858,8 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo for (int ibl = 0; ibl < nbl; ++ibl) { + //float block_mse = 0; + memset(&y[ibl], 0, sizeof(block_iq3_s)); y[ibl].d = GGML_FP32_TO_FP16(0.f); @@ -11812,9 +11899,10 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo scales[ib] = 0; continue; } + for (int k = 0; k < bs4; ++k) is_on_grid[k] = false; float best = 0; float scale = max/(2*kMaxQ-1); - for (int is = -15; is <= 15; ++is) { + for (int is = -9; is <= 9; ++is) { float id = (2*kMaxQ-1+is*0.2f)/max; float this_scale = 1/id; for (int k = 0; k < bs4; ++k) { @@ -11823,7 +11911,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); } uint16_t u = 0; - for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + for (int i = 0; i < 4; ++i) u |= Laux[4*k+i] << IQ3S_BITS*i; int grid_index = kmap_q3xs[u]; is_on_grid_aux[k] = true; if (grid_index < 0) { @@ -11850,12 +11938,12 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo if (n_not_ongrid > 0 && scale > 0) { float id = 1/scale; for (int k = 0; k < bs4; ++k) { - if (is_on_grid[k]) continue; + //if (is_on_grid[k]) continue; uint16_t u = 0; for (int i = 0; i < 4; ++i) { int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); l = MAX(0, MIN(kMaxQ-1, l)); - u |= (l << 3*i); + u |= l << IQ3S_BITS*i; } int grid_index = kmap_q3xs[u]; if (grid_index < 0) { @@ -11882,7 +11970,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } for (int k = 0; k < bs4; ++k) { uint16_t u = 0; - for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + for (int i = 0; i < 4; ++i) u |= L[4*k+i] << IQ3S_BITS*i; int grid_index = kmap_q3xs[u]; if (grid_index < 0) { printf("Oops: found point %u not on grid:", u); @@ -11899,6 +11987,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo GGML_ASSERT(scale >= 0); scales[ib] = scale; max_scale = MAX(max_scale, scale); + } if (!max_scale) { @@ -11906,7 +11995,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } float d = max_scale/31; - y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].d = GGML_FP32_TO_FP16(d * 1.030f); float id = 1/d; for (int ib = 0; ib < QK_K/block_size; ib += 2) { int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); @@ -11919,7 +12008,6 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } } -#define IQ3S_BLOCK_SIZE 32 size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; GGML_ASSERT(n_per_row%QK_K == 0); diff --git a/ggml-quants.h b/ggml-quants.h index 2c61134c4..cb7af5961 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -201,10 +201,11 @@ typedef struct { static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); // 3.4375 bpw +#define IQ3S_BLOCK_SIZE 32 #if QK_K == 64 #define IQ3S_N_SCALE 2 #else -#define IQ3S_N_SCALE QK_K/64 +#define IQ3S_N_SCALE QK_K/(2*IQ3S_BLOCK_SIZE) #endif typedef struct { ggml_fp16_t d;